# Home_assistant/backend/app/chat/service.py
# Snapshot: 2026-06-09 12:47:13 +03:00 (168 lines, 6.1 KiB, Python)

import json
import logging
from collections.abc import AsyncIterator
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.character.service import CharacterService
from app.chat.notices import format_pomodoro_context, format_pomodoro_notice
from app.db.models import ChatSession, Message
from app.llm.client import LLMClient
from app.pomodoro.service import PomodoroService
from app.projects.context import format_projects_context, get_projects_snapshot
from app.tools.registry import TOOL_DEFINITIONS, execute_tool

# Upper bound on LLM round-trips for a single user message. Each round may
# request tool calls; when the bound is hit, the stream ends with an error event.
MAX_TOOL_ROUNDS = 5
logger = logging.getLogger(__name__)


class ChatService:
    """Manage chat sessions and stream LLM replies, including tool-call rounds.

    The service persists every message through the SQLAlchemy session it is
    given and emits Server-Sent Events (SSE) strings from ``stream_response``.
    """

    # Placeholder title given to new sessions; ``_save_message`` replaces it
    # with an excerpt of a user message while the title still equals this value.
    DEFAULT_TITLE = "Новый чат"

    def __init__(self, db: Session):
        """Store the database session and create the LLM and character helpers."""
        self.db = db
        self.llm = LLMClient()
        self.character = CharacterService()

    def list_sessions(self) -> list[ChatSession]:
        """Return all chat sessions ordered by ``updated_at``, newest first."""
        stmt = select(ChatSession).order_by(ChatSession.updated_at.desc())
        return list(self.db.scalars(stmt).all())

    def get_session(self, session_id: int) -> ChatSession | None:
        """Return the session with the given primary key, or ``None``."""
        return self.db.get(ChatSession, session_id)

    def create_session(self, title: str = DEFAULT_TITLE) -> ChatSession:
        """Create, commit and return a new chat session with the given title."""
        session = ChatSession(title=title)
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def delete_session(self, session_id: int) -> bool:
        """Delete a session; return ``False`` when it does not exist."""
        session = self.get_session(session_id)
        if not session:
            return False
        self.db.delete(session)
        self.db.commit()
        return True

    def _build_system_prompt(self) -> str:
        """Compose the system prompt from character, pomodoro and project context."""
        status = PomodoroService(self.db).get_status()
        projects_snapshot = get_projects_snapshot(self.db)
        return (
            f"{self.character.get_system_prompt()}\n\n"
            f"{format_pomodoro_context(status)}\n\n"
            f"{format_projects_context(projects_snapshot)}"
        )

    def _build_messages(self, session: ChatSession) -> list[dict[str, Any]]:
        """Rebuild the LLM message list from the stored session history.

        Messages with role ``"notice"`` are skipped: they are saved for display
        and are not sent to the model.
        """
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": self._build_system_prompt()}
        ]
        for msg in session.messages:
            if msg.role == "notice":
                continue
            # Keep content a string by default so a rebuilt tool/user message
            # matches what stream_response sends live (the result verbatim).
            entry: dict[str, Any] = {"role": msg.role, "content": msg.content or ""}
            if msg.tool_calls_json:
                entry["tool_calls"] = json.loads(msg.tool_calls_json)
                if not msg.content:
                    # Null content is used only next to tool_calls, mirroring
                    # the assistant message built in stream_response.
                    # NOTE(review): assumes LLMClient speaks an OpenAI-style
                    # chat schema — confirm.
                    entry["content"] = None
            if msg.role == "tool" and msg.tool_call_id:
                entry["tool_call_id"] = msg.tool_call_id
            messages.append(entry)
        return messages

    def _save_message(
        self,
        session_id: int,
        role: str,
        content: str = "",
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
    ) -> Message:
        """Persist one message and commit.

        When a non-empty user message arrives while the session still carries
        the default title, the title becomes the first 60 characters of the
        message, followed by ``...`` if the message is longer.
        """
        message = Message(
            session_id=session_id,
            role=role,
            content=content,
            # ensure_ascii=False stores non-ASCII text as-is instead of \u escapes.
            tool_calls_json=json.dumps(tool_calls, ensure_ascii=False) if tool_calls else None,
            tool_call_id=tool_call_id,
        )
        self.db.add(message)

        session = self.get_session(session_id)
        if session and role == "user" and session.title == self.DEFAULT_TITLE and content:
            session.title = content[:60] + ("..." if len(content) > 60 else "")

        self.db.commit()
        self.db.refresh(message)
        return message

    async def _run_tool(self, name: str, raw_arguments: Any) -> tuple[str, bool]:
        """Parse arguments and execute one tool call.

        Returns ``(result, succeeded)``. A failure is logged and converted into
        a JSON error payload instead of propagating, so the caller can still
        record a tool message for the call.
        """
        try:
            args = LLMClient.parse_tool_arguments(raw_arguments)
            return await execute_tool(self.db, name, args), True
        except Exception as exc:  # tool boundary: any failure becomes a tool result
            logger.exception("Tool %r failed", name)
            # Discard whatever the failed tool left pending so that the
            # following commits in _save_message start from a clean session.
            self.db.rollback()
            payload = {"error": f"Tool '{name}' failed: {exc}"}
            return json.dumps(payload, ensure_ascii=False), False

    @staticmethod
    def _decode_tool_result(result: Any) -> Any:
        """Return the tool result parsed as JSON, or unchanged if it is not JSON."""
        try:
            return json.loads(result)
        except (TypeError, ValueError):
            return result

    async def stream_response(self, session_id: int, user_text: str) -> AsyncIterator[str]:
        """Stream the assistant's reply to ``user_text`` as SSE strings.

        Events emitted: ``token`` for each content chunk, ``notice`` and
        ``pomodoro`` after a tool call that produces a pomodoro notice,
        ``done`` when the model answers without tool calls, and ``error`` when
        the session does not exist or ``MAX_TOOL_ROUNDS`` rounds all ended in
        tool calls.
        """
        session = self.get_session(session_id)
        if not session:
            yield self._sse("error", {"message": "Session not found"})
            return

        self._save_message(session_id, "user", user_text)
        # NOTE(review): relies on session.messages reflecting the row saved
        # above — confirm the Session's expiry settings.
        messages = self._build_messages(session)

        for _ in range(MAX_TOOL_ROUNDS):
            content_parts: list[str] = []
            tool_calls: list[dict[str, Any]] = []

            async for event in self.llm.stream_chat(messages, tools=TOOL_DEFINITIONS):
                if event["type"] == "content":
                    content_parts.append(event["content"])
                    yield self._sse("token", {"content": event["content"]})
                elif event["type"] == "tool_calls":
                    tool_calls = event["tool_calls"]

            content = "".join(content_parts)

            if not tool_calls:
                # Plain answer: persist it (if any text was produced) and finish.
                if content:
                    self._save_message(session_id, "assistant", content)
                yield self._sse("done", {})
                return

            messages.append(
                {
                    "role": "assistant",
                    "content": content or None,
                    "tool_calls": tool_calls,
                }
            )
            self._save_message(session_id, "assistant", content, tool_calls=tool_calls)

            for tool_call in tool_calls:
                fn = tool_call["function"]
                # The assistant message above is already committed, so every
                # tool call must get a tool message even when the tool fails;
                # otherwise the stored history keeps an unanswered call.
                result, succeeded = await self._run_tool(
                    fn["name"], fn.get("arguments", "")
                )
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "content": result,
                    }
                )
                self._save_message(session_id, "tool", result, tool_call_id=tool_call["id"])

                if not succeeded:
                    # The error payload is for the model only; no notice for it.
                    continue

                notice = format_pomodoro_notice(fn["name"], result)
                if notice:
                    self._save_message(session_id, "notice", notice)
                    yield self._sse("notice", {"content": notice})
                    yield self._sse(
                        "pomodoro",
                        {"name": fn["name"], "result": self._decode_tool_result(result)},
                    )

        yield self._sse("error", {"message": "Too many tool call rounds"})

    @staticmethod
    def _sse(event: str, data: dict[str, Any]) -> str:
        """Format one Server-Sent Event with a JSON-encoded ``data`` field."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"