init
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import chat, health, pomodoro
|
||||
|
||||
api_router = APIRouter(prefix="/api/v1")
|
||||
api_router.include_router(health.router, tags=["health"])
|
||||
api_router.include_router(chat.router, prefix="/chat", tags=["chat"])
|
||||
api_router.include_router(pomodoro.router, prefix="/pomodoro", tags=["pomodoro"])
|
||||
@@ -0,0 +1,55 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.schemas import MessageCreate, MessageOut, SessionCreate, SessionDetailOut, SessionOut
|
||||
from app.chat.service import ChatService
|
||||
from app.db.base import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=SessionOut)
|
||||
def create_session(payload: SessionCreate, db: Session = Depends(get_db)) -> SessionOut:
|
||||
service = ChatService(db)
|
||||
return service.create_session(title=payload.title)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=list[SessionOut])
|
||||
def list_sessions(db: Session = Depends(get_db)) -> list[SessionOut]:
|
||||
service = ChatService(db)
|
||||
return service.list_sessions()
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=SessionDetailOut)
|
||||
def get_session(session_id: int, db: Session = Depends(get_db)) -> SessionDetailOut:
|
||||
service = ChatService(db)
|
||||
session = service.get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
def delete_session(session_id: int, db: Session = Depends(get_db)) -> dict[str, bool]:
|
||||
service = ChatService(db)
|
||||
if not service.delete_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/messages")
|
||||
async def send_message(
|
||||
session_id: int,
|
||||
payload: MessageCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
service = ChatService(db)
|
||||
if not service.get_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
async def event_stream():
|
||||
async for chunk in service.stream_response(session_id, payload.content):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
@@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
def health() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
@@ -0,0 +1,60 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.schemas import PomodoroStart, PomodoroStop
|
||||
from app.db.base import get_db
|
||||
from app.pomodoro.service import PomodoroService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _handle_value_error(exc: ValueError) -> HTTPException:
|
||||
return HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
def get_status(db: Session = Depends(get_db)) -> dict:
|
||||
return PomodoroService(db).get_status()
|
||||
|
||||
|
||||
@router.post("/start")
|
||||
def start_pomodoro(payload: PomodoroStart, db: Session = Depends(get_db)) -> dict:
|
||||
try:
|
||||
return PomodoroService(db).start(
|
||||
duration_min=payload.duration_min,
|
||||
task_note=payload.task_note,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _handle_value_error(exc) from exc
|
||||
|
||||
|
||||
@router.post("/pause")
|
||||
def pause_pomodoro(db: Session = Depends(get_db)) -> dict:
|
||||
try:
|
||||
return PomodoroService(db).pause()
|
||||
except ValueError as exc:
|
||||
raise _handle_value_error(exc) from exc
|
||||
|
||||
|
||||
@router.post("/resume")
|
||||
def resume_pomodoro(db: Session = Depends(get_db)) -> dict:
|
||||
try:
|
||||
return PomodoroService(db).resume()
|
||||
except ValueError as exc:
|
||||
raise _handle_value_error(exc) from exc
|
||||
|
||||
|
||||
@router.post("/stop")
|
||||
def stop_pomodoro(payload: PomodoroStop, db: Session = Depends(get_db)) -> dict:
|
||||
try:
|
||||
return PomodoroService(db).stop(
|
||||
result=payload.result,
|
||||
completed=payload.completed,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _handle_value_error(exc) from exc
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
def get_history(limit: int = 20, db: Session = Depends(get_db)) -> list[dict]:
|
||||
return PomodoroService(db).history(limit=limit)
|
||||
@@ -0,0 +1,43 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SessionCreate(BaseModel):
|
||||
title: str = "Новый чат"
|
||||
|
||||
|
||||
class SessionOut(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class MessageOut(BaseModel):
|
||||
id: int
|
||||
role: str
|
||||
content: str
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SessionDetailOut(SessionOut):
|
||||
messages: list[MessageOut]
|
||||
|
||||
|
||||
class MessageCreate(BaseModel):
|
||||
content: str = Field(min_length=1)
|
||||
|
||||
|
||||
class PomodoroStart(BaseModel):
|
||||
duration_min: int = Field(default=25, ge=1, le=180)
|
||||
task_note: str = ""
|
||||
|
||||
|
||||
class PomodoroStop(BaseModel):
|
||||
result: str = ""
|
||||
completed: bool = False
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.chat.service import ChatService
|
||||
|
||||
__all__ = ["ChatService"]
|
||||
@@ -0,0 +1,141 @@
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import get_settings
|
||||
from app.db.models import ChatSession, Message
|
||||
from app.llm.client import LLMClient
|
||||
from app.tools.registry import TOOL_DEFINITIONS, execute_tool
|
||||
|
||||
MAX_TOOL_ROUNDS = 5
|
||||
|
||||
|
||||
class ChatService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.llm = LLMClient()
|
||||
self.system_prompt = get_settings().load_system_prompt()
|
||||
|
||||
def list_sessions(self) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).order_by(ChatSession.updated_at.desc())
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def get_session(self, session_id: int) -> ChatSession | None:
|
||||
return self.db.get(ChatSession, session_id)
|
||||
|
||||
def create_session(self, title: str = "Новый чат") -> ChatSession:
|
||||
session = ChatSession(title=title)
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
return session
|
||||
|
||||
def delete_session(self, session_id: int) -> bool:
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
return False
|
||||
self.db.delete(session)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
def _build_messages(self, session: ChatSession) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = [{"role": "system", "content": self.system_prompt}]
|
||||
for msg in session.messages:
|
||||
content = msg.content or None
|
||||
entry: dict[str, Any] = {"role": msg.role, "content": content}
|
||||
if msg.tool_calls_json:
|
||||
entry["tool_calls"] = json.loads(msg.tool_calls_json)
|
||||
if not content:
|
||||
entry["content"] = None
|
||||
if msg.role == "tool" and msg.tool_call_id:
|
||||
entry["tool_call_id"] = msg.tool_call_id
|
||||
messages.append(entry)
|
||||
return messages
|
||||
|
||||
def _save_message(
|
||||
self,
|
||||
session_id: int,
|
||||
role: str,
|
||||
content: str = "",
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
) -> Message:
|
||||
message = Message(
|
||||
session_id=session_id,
|
||||
role=role,
|
||||
content=content,
|
||||
tool_calls_json=json.dumps(tool_calls, ensure_ascii=False) if tool_calls else None,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
self.db.add(message)
|
||||
session = self.get_session(session_id)
|
||||
if session and role == "user" and session.title == "Новый чат" and content:
|
||||
session.title = content[:60] + ("..." if len(content) > 60 else "")
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
async def stream_response(self, session_id: int, user_text: str) -> AsyncIterator[str]:
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
yield self._sse("error", {"message": "Session not found"})
|
||||
return
|
||||
|
||||
self._save_message(session_id, "user", user_text)
|
||||
messages = self._build_messages(session)
|
||||
|
||||
for _ in range(MAX_TOOL_ROUNDS):
|
||||
content_parts: list[str] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
async for event in self.llm.stream_chat(messages, tools=TOOL_DEFINITIONS):
|
||||
if event["type"] == "content":
|
||||
content_parts.append(event["content"])
|
||||
yield self._sse("token", {"content": event["content"]})
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
|
||||
if tool_calls:
|
||||
assistant_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": "".join(content_parts) or None,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
messages.append(assistant_msg)
|
||||
self._save_message(
|
||||
session_id,
|
||||
"assistant",
|
||||
"".join(content_parts),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
for tool_call in tool_calls:
|
||||
fn = tool_call["function"]
|
||||
args = LLMClient.parse_tool_arguments(fn.get("arguments", ""))
|
||||
result = execute_tool(self.db, fn["name"], args)
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": result,
|
||||
}
|
||||
messages.append(tool_message)
|
||||
self._save_message(session_id, "tool", result, tool_call_id=tool_call["id"])
|
||||
yield self._sse("tool", {"name": fn["name"], "result": json.loads(result)})
|
||||
|
||||
continue
|
||||
|
||||
final_content = "".join(content_parts)
|
||||
if final_content:
|
||||
self._save_message(session_id, "assistant", final_content)
|
||||
|
||||
yield self._sse("done", {})
|
||||
return
|
||||
|
||||
yield self._sse("error", {"message": "Too many tool call rounds"})
|
||||
|
||||
@staticmethod
|
||||
def _sse(event: str, data: dict[str, Any]) -> str:
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
@@ -0,0 +1,38 @@
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=(".env", "../.env"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8080
|
||||
|
||||
openrouter_api_key: str = ""
|
||||
openrouter_model: str = "deepseek/deepseek-chat"
|
||||
openrouter_base_url: str = "https://openrouter.ai/api/v1"
|
||||
|
||||
database_url: str = "sqlite:///./data/assistant.db"
|
||||
cors_origins: str = "http://localhost:5173,http://localhost:8080,http://localhost:3000"
|
||||
system_prompt_path: str = "./prompts/assistant.md"
|
||||
|
||||
@property
|
||||
def cors_origins_list(self) -> list[str]:
|
||||
return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
|
||||
|
||||
def load_system_prompt(self) -> str:
|
||||
path = Path(self.system_prompt_path)
|
||||
if path.is_file():
|
||||
return path.read_text(encoding="utf-8")
|
||||
return "Ты домашний ИИ-ассистент. Общайся на русском."
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
@@ -0,0 +1,11 @@
|
||||
from app.db.base import Base, get_db, init_db
|
||||
from app.db.models import ChatSession, Message, PomodoroSession
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"ChatSession",
|
||||
"Message",
|
||||
"PomodoroSession",
|
||||
"get_db",
|
||||
"init_db",
|
||||
]
|
||||
@@ -0,0 +1,39 @@
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_sqlite_dir(database_url: str) -> None:
|
||||
if database_url.startswith("sqlite:///"):
|
||||
db_path = database_url.replace("sqlite:///", "", 1)
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
_ensure_sqlite_dir(settings.database_url)
|
||||
|
||||
connect_args = {"check_same_thread": False} if settings.database_url.startswith("sqlite") else {}
|
||||
engine = create_engine(settings.database_url, connect_args=connect_args)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
from app.db import models # noqa: F401
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
@@ -0,0 +1,51 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class ChatSession(Base):
|
||||
__tablename__ = "chat_sessions"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
title: Mapped[str] = mapped_column(String(255), default="Новый чат")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
messages: Mapped[list["Message"]] = relationship(
|
||||
back_populates="session", cascade="all, delete-orphan", order_by="Message.created_at"
|
||||
)
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
session_id: Mapped[int] = mapped_column(ForeignKey("chat_sessions.id", ondelete="CASCADE"), index=True)
|
||||
role: Mapped[str] = mapped_column(String(32))
|
||||
content: Mapped[str] = mapped_column(Text, default="")
|
||||
tool_calls_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tool_call_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
session: Mapped["ChatSession"] = relationship(back_populates="messages")
|
||||
|
||||
|
||||
class PomodoroSession(Base):
|
||||
__tablename__ = "pomodoro_sessions"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
status: Mapped[str] = mapped_column(String(32), default="idle")
|
||||
duration_min: Mapped[int] = mapped_column(Integer, default=25)
|
||||
task_note: Mapped[str] = mapped_column(Text, default="")
|
||||
result: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
completed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
paused_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
elapsed_seconds: Mapped[int] = mapped_column(Integer, default=0)
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.llm.client import LLMClient
|
||||
|
||||
__all__ = ["LLMClient"]
|
||||
@@ -0,0 +1,112 @@
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self) -> None:
|
||||
settings = get_settings()
|
||||
self.model = settings.openrouter_model
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=settings.openrouter_api_key,
|
||||
base_url=settings.openrouter_base_url,
|
||||
)
|
||||
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
|
||||
stream = await self.client.chat.completions.create(**kwargs)
|
||||
|
||||
tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
|
||||
if delta.content:
|
||||
yield {"type": "content", "content": delta.content}
|
||||
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
idx = tool_call.index
|
||||
if idx not in tool_calls:
|
||||
tool_calls[idx] = {
|
||||
"id": tool_call.id or "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
if tool_call.id:
|
||||
tool_calls[idx]["id"] = tool_call.id
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
tool_calls[idx]["function"]["name"] = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
tool_calls[idx]["function"]["arguments"] += tool_call.function.arguments
|
||||
|
||||
if choice.finish_reason:
|
||||
if tool_calls:
|
||||
yield {"type": "tool_calls", "tool_calls": list(tool_calls.values())}
|
||||
yield {"type": "done", "finish_reason": choice.finish_reason}
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
|
||||
response = await self.client.chat.completions.create(**kwargs)
|
||||
message = response.choices[0].message
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"content": message.content or "",
|
||||
"tool_calls": [],
|
||||
}
|
||||
|
||||
if message.tool_calls:
|
||||
result["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in message.tool_calls
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def parse_tool_arguments(arguments: str) -> dict[str, Any]:
|
||||
if not arguments:
|
||||
return {}
|
||||
try:
|
||||
return json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
@@ -0,0 +1,33 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.routes import api_router
|
||||
from app.config import get_settings
|
||||
from app.db.base import init_db
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
init_db()
|
||||
yield
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
settings = get_settings()
|
||||
app = FastAPI(title="Home AI Assistant", lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins_list,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(api_router)
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.pomodoro.service import PomodoroService
|
||||
|
||||
__all__ = ["PomodoroService"]
|
||||
@@ -0,0 +1,152 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.models import PomodoroSession
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class PomodoroService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def _get_active(self) -> PomodoroSession | None:
|
||||
stmt = (
|
||||
select(PomodoroSession)
|
||||
.where(PomodoroSession.status.in_(("running", "paused")))
|
||||
.order_by(PomodoroSession.id.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def _elapsed(self, session: PomodoroSession) -> int:
|
||||
elapsed = session.elapsed_seconds
|
||||
if session.status == "running" and session.started_at:
|
||||
started = session.started_at
|
||||
if started.tzinfo is None:
|
||||
started = started.replace(tzinfo=timezone.utc)
|
||||
delta = _utcnow() - started
|
||||
elapsed += int(delta.total_seconds())
|
||||
return elapsed
|
||||
|
||||
def _remaining(self, session: PomodoroSession) -> int:
|
||||
total = session.duration_min * 60
|
||||
return max(0, total - self._elapsed(session))
|
||||
|
||||
def _to_status_dict(self, session: PomodoroSession | None) -> dict:
|
||||
if not session:
|
||||
return {
|
||||
"status": "idle",
|
||||
"duration_min": 25,
|
||||
"task_note": "",
|
||||
"elapsed_seconds": 0,
|
||||
"remaining_seconds": 0,
|
||||
"session_id": None,
|
||||
}
|
||||
|
||||
elapsed = self._elapsed(session)
|
||||
total = session.duration_min * 60
|
||||
remaining = max(0, total - elapsed)
|
||||
|
||||
if session.status == "running" and remaining == 0:
|
||||
session.status = "completed"
|
||||
session.finished_at = _utcnow()
|
||||
session.completed = True
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
|
||||
return {
|
||||
"status": session.status,
|
||||
"duration_min": session.duration_min,
|
||||
"task_note": session.task_note,
|
||||
"elapsed_seconds": elapsed,
|
||||
"remaining_seconds": remaining,
|
||||
"session_id": session.id,
|
||||
"started_at": session.started_at.isoformat() if session.started_at else None,
|
||||
"finished_at": session.finished_at.isoformat() if session.finished_at else None,
|
||||
}
|
||||
|
||||
def get_status(self) -> dict:
|
||||
return self._to_status_dict(self._get_active())
|
||||
|
||||
def start(self, duration_min: int = 25, task_note: str = "") -> dict:
|
||||
active = self._get_active()
|
||||
if active:
|
||||
raise ValueError("Таймер уже запущен. Сначала остановите текущую сессию.")
|
||||
|
||||
session = PomodoroSession(
|
||||
status="running",
|
||||
duration_min=duration_min,
|
||||
task_note=task_note,
|
||||
started_at=_utcnow(),
|
||||
)
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
return self._to_status_dict(session)
|
||||
|
||||
def pause(self) -> dict:
|
||||
session = self._get_active()
|
||||
if not session or session.status != "running":
|
||||
raise ValueError("Нет активного запущенного таймера.")
|
||||
|
||||
session.elapsed_seconds = self._elapsed(session)
|
||||
session.status = "paused"
|
||||
session.paused_at = _utcnow()
|
||||
session.started_at = None
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
return self._to_status_dict(session)
|
||||
|
||||
def resume(self) -> dict:
|
||||
session = self._get_active()
|
||||
if not session or session.status != "paused":
|
||||
raise ValueError("Нет таймера на паузе.")
|
||||
|
||||
session.status = "running"
|
||||
session.started_at = _utcnow()
|
||||
session.paused_at = None
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
return self._to_status_dict(session)
|
||||
|
||||
def stop(self, result: str = "", completed: bool = False) -> dict:
|
||||
session = self._get_active()
|
||||
if not session:
|
||||
raise ValueError("Нет активного таймера.")
|
||||
|
||||
session.elapsed_seconds = self._elapsed(session)
|
||||
session.status = "completed" if completed else "cancelled"
|
||||
session.result = result
|
||||
session.completed = completed
|
||||
session.finished_at = _utcnow()
|
||||
session.started_at = None
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
return self._to_status_dict(session)
|
||||
|
||||
def history(self, limit: int = 20) -> list[dict]:
|
||||
stmt = (
|
||||
select(PomodoroSession)
|
||||
.where(PomodoroSession.status.in_(("completed", "cancelled")))
|
||||
.order_by(PomodoroSession.finished_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
sessions = self.db.scalars(stmt).all()
|
||||
return [
|
||||
{
|
||||
"id": s.id,
|
||||
"status": s.status,
|
||||
"duration_min": s.duration_min,
|
||||
"task_note": s.task_note,
|
||||
"result": s.result,
|
||||
"completed": s.completed,
|
||||
"elapsed_seconds": s.elapsed_seconds,
|
||||
"finished_at": s.finished_at.isoformat() if s.finished_at else None,
|
||||
}
|
||||
for s in sessions
|
||||
]
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.tools.registry import TOOL_DEFINITIONS, execute_tool
|
||||
|
||||
__all__ = ["TOOL_DEFINITIONS", "execute_tool"]
|
||||
@@ -0,0 +1,102 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.pomodoro.service import PomodoroService
|
||||
|
||||
TOOL_DEFINITIONS: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_pomodoro_status",
|
||||
"description": "Получить текущий статус помидоро-таймера",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "start_pomodoro",
|
||||
"description": "Запустить помидоро-таймер",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"duration_min": {
|
||||
"type": "integer",
|
||||
"description": "Длительность в минутах, по умолчанию 25",
|
||||
},
|
||||
"task_note": {
|
||||
"type": "string",
|
||||
"description": "Над чем работаем в этой сессии",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "stop_pomodoro",
|
||||
"description": "Остановить текущий помидоро-таймер",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {
|
||||
"type": "string",
|
||||
"description": "Краткий отчёт о том, что сделано",
|
||||
},
|
||||
"completed": {
|
||||
"type": "boolean",
|
||||
"description": "True если задача полностью завершена",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_pomodoro_history",
|
||||
"description": "Получить историю завершённых помидоро-сессий",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Сколько последних сессий вернуть, по умолчанию 10",
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def execute_tool(db: Session, name: str, arguments: dict[str, Any]) -> str:
|
||||
service = PomodoroService(db)
|
||||
|
||||
try:
|
||||
if name == "get_pomodoro_status":
|
||||
result = service.get_status()
|
||||
elif name == "start_pomodoro":
|
||||
result = service.start(
|
||||
duration_min=arguments.get("duration_min", 25),
|
||||
task_note=arguments.get("task_note", ""),
|
||||
)
|
||||
elif name == "stop_pomodoro":
|
||||
result = service.stop(
|
||||
result=arguments.get("result", ""),
|
||||
completed=arguments.get("completed", False),
|
||||
)
|
||||
elif name == "get_pomodoro_history":
|
||||
result = service.history(limit=arguments.get("limit", 10))
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown tool: {name}"}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except ValueError as exc:
|
||||
return json.dumps({"error": str(exc)}, ensure_ascii=False)
|
||||
Reference in New Issue
Block a user