220 lines
8.6 KiB
Python
220 lines
8.6 KiB
Python
import json
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.config import get_settings
|
|
from app.character.service import CharacterService
|
|
from app.chat.notices import (
|
|
POMODORO_TOOL_NAMES,
|
|
format_pomodoro_context,
|
|
format_tool_notice,
|
|
)
|
|
from app.fitness.context import format_fitness_context, get_fitness_snapshot
|
|
from app.homelab.context import format_datetime_context
|
|
from app.homelab.openmeteo import format_weather_snapshot
|
|
from app.memory.context import (
|
|
format_identity_hint,
|
|
format_memory_context,
|
|
get_memory_snapshot,
|
|
)
|
|
from app.memory.extract import extract_after_turn
|
|
from app.projects.context import format_projects_context, get_projects_snapshot
|
|
from app.shopping.context import format_shopping_context, get_shopping_snapshot
|
|
from app.db.models import ChatSession, Message
|
|
from app.llm.client import LLMClient
|
|
from app.pomodoro.service import PomodoroService
|
|
from app.tools.registry import TOOL_DEFINITIONS, execute_tool
|
|
|
|
# Upper bound on LLM round-trips in a single user turn; every round except the
# last one may request tool calls that are executed before the next round.
MAX_TOOL_ROUNDS: int = 5
# How many of the most recent non-notice messages are sent to the LLM as history.
MAX_HISTORY_MESSAGES: int = 40
|
|
|
|
|
|
class ChatService:
    """Chat orchestration: session CRUD, LLM context assembly and streamed turns.

    A turn (:meth:`stream_response`) persists the user message, builds the LLM
    context (system prompt + recent history), streams the model output as
    Server-Sent Events, executes requested tools for up to ``MAX_TOOL_ROUNDS``
    rounds and finally persists the assistant reply.
    """

    # Placeholder title for new sessions; replaced by the first user message
    # (see ``_save_message``).
    DEFAULT_TITLE = "Новый чат"
    # Number of characters of the first user message kept as the session title.
    TITLE_MAX_LENGTH = 60

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: SQLAlchemy session used for every read and write of this service.
        """
        self.db = db
        self.llm = LLMClient()
        self.character = CharacterService()

    def list_sessions(self) -> list[ChatSession]:
        """Return all chat sessions, most recently updated first."""
        stmt = select(ChatSession).order_by(ChatSession.updated_at.desc())
        return list(self.db.scalars(stmt).all())

    def get_session(self, session_id: int) -> ChatSession | None:
        """Return the session with primary key ``session_id``, or ``None``."""
        return self.db.get(ChatSession, session_id)

    def create_session(self, title: str = DEFAULT_TITLE) -> ChatSession:
        """Create, commit and return a new session (refreshed from the DB)."""
        session = ChatSession(title=title)
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def delete_session(self, session_id: int) -> bool:
        """Delete the session with ``session_id``.

        Returns:
            ``True`` if the session existed and was deleted, ``False`` otherwise.
        """
        session = self.get_session(session_id)
        if not session:
            return False
        self.db.delete(session)
        self.db.commit()
        return True

    def _build_system_prompt(
        self,
        session_id: int | None = None,
        *,
        memory_snapshot: Any = None,
    ) -> str:
        """Assemble the system prompt from the character prompt and context sections.

        Sections (character, datetime, memory, fitness, shopping, weather,
        pomodoro, projects) are joined with blank lines.

        Args:
            session_id: Session whose memory snapshot is used.
            memory_snapshot: Pre-computed result of ``get_memory_snapshot`` for
                ``session_id``; computed here when omitted (``None``). Lets
                callers that also need the snapshot avoid a second query.
        """
        status = PomodoroService(self.db).get_status()
        if memory_snapshot is None:
            memory_snapshot = get_memory_snapshot(self.db, session_id)
        fitness_snapshot = get_fitness_snapshot(self.db)
        shopping_snapshot = get_shopping_snapshot(self.db)
        projects_snapshot = get_projects_snapshot(self.db)
        return (
            f"{self.character.get_system_prompt()}\n\n"
            f"{format_datetime_context(self.db)}\n\n"
            f"{format_memory_context(memory_snapshot)}\n\n"
            f"{format_fitness_context(fitness_snapshot)}\n\n"
            f"{format_shopping_context(shopping_snapshot)}\n\n"
            f"{format_weather_snapshot()}\n\n"
            f"{format_pomodoro_context(status)}\n\n"
            f"{format_projects_context(projects_snapshot)}"
        )

    def _build_messages(self, session: ChatSession) -> list[dict[str, Any]]:
        """Build the LLM message list: system prompt followed by recent history.

        ``notice`` rows are excluded (they are UI-only events emitted by
        ``stream_response``). History is limited to ``MAX_HISTORY_MESSAGES``;
        when it is truncated, a note is appended to the system prompt.
        """
        # One snapshot serves both the memory section and the identity hint.
        memory_snapshot = get_memory_snapshot(self.db, session.id)
        system_prompt = self._build_system_prompt(
            session.id, memory_snapshot=memory_snapshot
        )

        all_chat = [m for m in session.messages if m.role != "notice"]

        last_user = next((m.content for m in reversed(all_chat) if m.role == "user"), "")
        if last_user:
            identity_hint = format_identity_hint(memory_snapshot, last_user)
            if identity_hint:
                system_prompt += f"\n\n{identity_hint}"

        history = self._trim_history(all_chat)
        if len(history) < len(all_chat):
            system_prompt += (
                f"\n\n[История чата: в контексте последние {len(history)} "
                f"из {len(all_chat)} сообщений. Раннее — в сводке сессии, если сохранена.]"
            )

        messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
        messages.extend(self._message_to_entry(msg) for msg in history)
        return messages

    @staticmethod
    def _trim_history(chat: list[Message]) -> list[Message]:
        """Return the last ``MAX_HISTORY_MESSAGES`` messages, minus leading tool replies.

        A plain tail slice can start in the middle of a tool exchange, leaving
        ``tool`` messages whose assistant ``tool_calls`` message was cut off.
        OpenAI-compatible chat APIs reject such orphaned tool messages, so they
        are dropped from the start of the window.
        """
        window = chat[-MAX_HISTORY_MESSAGES:]
        start = 0
        while start < len(window) and window[start].role == "tool":
            start += 1
        return window[start:]

    @staticmethod
    def _message_to_entry(msg: Message) -> dict[str, Any]:
        """Convert a stored ``Message`` row into an LLM chat message dict."""
        # Empty content is sent as None (e.g. assistant turns that only call tools).
        entry: dict[str, Any] = {"role": msg.role, "content": msg.content or None}
        if msg.tool_calls_json:
            entry["tool_calls"] = json.loads(msg.tool_calls_json)
        if msg.role == "tool" and msg.tool_call_id:
            entry["tool_call_id"] = msg.tool_call_id
        return entry

    def _save_message(
        self,
        session_id: int,
        role: str,
        content: str = "",
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
    ) -> Message:
        """Persist one message (committed immediately) and return it refreshed.

        The first non-empty user message of a session still titled
        ``DEFAULT_TITLE`` also becomes the session title, truncated to
        ``TITLE_MAX_LENGTH`` characters with a trailing ``...``.
        """
        message = Message(
            session_id=session_id,
            role=role,
            content=content,
            tool_calls_json=json.dumps(tool_calls, ensure_ascii=False) if tool_calls else None,
            tool_call_id=tool_call_id,
        )
        self.db.add(message)
        # Only a user message can rename the session, so skip the lookup otherwise.
        if role == "user" and content:
            session = self.get_session(session_id)
            if session and session.title == self.DEFAULT_TITLE:
                limit = self.TITLE_MAX_LENGTH
                session.title = content[:limit] + ("..." if len(content) > limit else "")
        self.db.commit()
        self.db.refresh(message)
        return message

    async def stream_response(self, session_id: int, user_text: str) -> AsyncIterator[str]:
        """Run one chat turn and stream it to the client as SSE frames.

        Persists ``user_text``, then asks the LLM for a reply. Each round
        streams ``token`` events. If the model requests tools, they are executed
        and persisted, reported via ``notice`` / ``pomodoro`` events, and fed back
        for another round. A round without tool calls persists the reply and
        ends with a ``done`` event (carrying memory-extraction stats when
        ``memory_auto_extract`` is enabled). Emits ``error`` when the session
        does not exist or all ``MAX_TOOL_ROUNDS`` rounds requested tools.
        """
        session = self.get_session(session_id)
        if not session:
            yield self._sse("error", {"message": "Session not found"})
            return

        self._save_message(session_id, "user", user_text)
        # NOTE(review): relies on session.messages being (re)loaded after the
        # commit above so the new user message is included — SQLAlchemy expires
        # instances on commit by default; confirm expire_on_commit is not disabled.
        messages = self._build_messages(session)

        for _ in range(MAX_TOOL_ROUNDS):
            content_parts: list[str] = []
            tool_calls: list[dict[str, Any]] = []

            async for event in self.llm.stream_chat(messages, tools=TOOL_DEFINITIONS):
                if event["type"] == "content":
                    content_parts.append(event["content"])
                    yield self._sse("token", {"content": event["content"]})
                elif event["type"] == "tool_calls":
                    tool_calls = event["tool_calls"]

            round_content = "".join(content_parts)

            if not tool_calls:
                # Final answer for this turn.
                if round_content:
                    self._save_message(session_id, "assistant", round_content)
                memory_meta = await self._extract_memory(session_id, user_text, round_content)
                yield self._sse("done", memory_meta)
                return

            messages.append(
                {
                    "role": "assistant",
                    "content": round_content or None,
                    "tool_calls": tool_calls,
                }
            )
            self._save_message(session_id, "assistant", round_content, tool_calls=tool_calls)

            for tool_call in tool_calls:
                fn = tool_call["function"]
                name = fn["name"]
                args = LLMClient.parse_tool_arguments(fn.get("arguments", ""))
                result = await execute_tool(self.db, name, args, session_id=session_id)

                messages.append(
                    {"role": "tool", "tool_call_id": tool_call["id"], "content": result}
                )
                self._save_message(session_id, "tool", result, tool_call_id=tool_call["id"])

                notice = format_tool_notice(name, result)
                if notice:
                    self._save_message(session_id, "notice", notice)
                    yield self._sse("notice", {"content": notice})

                if name in POMODORO_TOOL_NAMES:
                    yield self._sse(
                        "pomodoro",
                        {"name": name, "result": self._decode_tool_result(result)},
                    )

        yield self._sse("error", {"message": "Too many tool call rounds"})

    async def _extract_memory(
        self, session_id: int, user_text: str, reply: str
    ) -> dict[str, Any]:
        """Run post-turn memory extraction when enabled; return ``done`` metadata.

        Returns an empty dict when ``memory_auto_extract`` is off.
        """
        if not get_settings().memory_auto_extract:
            return {}
        extraction = await extract_after_turn(self.db, session_id, user_text, reply)
        return {
            "memory_extracted": extraction.get("count", 0),
            "memory_saved": extraction.get("saved", []),
        }

    @staticmethod
    def _decode_tool_result(result: Any) -> Any:
        """Parse a tool result as JSON, falling back to the raw value.

        A tool may return a non-JSON string (e.g. an error message). Raising
        here would abort the SSE stream after messages were already persisted.
        """
        try:
            return json.loads(result)
        except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
            return result

    @staticmethod
    def _sse(event: str, data: dict[str, Any]) -> str:
        """Format one Server-Sent Events frame with a JSON ``data`` payload."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"