Files
Home_assistant/backend/app/chat/service.py
T
2026-06-10 14:37:27 +03:00

279 lines
11 KiB
Python

import asyncio
import json
import logging
from collections.abc import AsyncIterator
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import get_settings
from app.db.base import SessionLocal
from app.character.service import CharacterService
from app.chat.notices import (
POMODORO_TOOL_NAMES,
format_pomodoro_context,
format_tool_notice,
)
from app.fitness.context import format_fitness_context, get_fitness_snapshot
from app.homelab.context import format_datetime_context
from app.homelab.openmeteo import format_weather_snapshot
from app.memory.context import (
format_identity_hint,
format_memory_context,
get_memory_snapshot,
)
from app.memory.extract import extract_after_turn
from app.projects.context import format_projects_context, get_projects_snapshot
from app.shopping.context import format_shopping_context, get_shopping_snapshot
from app.db.models import ChatSession, Message
from app.llm.client import LLMClient
from app.pomodoro.service import PomodoroService
from app.tools.registry import TOOL_DEFINITIONS, execute_tool
# Upper bound on LLM calls per user turn: ChatService.stream_response loops at
# most this many times, starting a new round whenever the model asks for tools.
MAX_TOOL_ROUNDS = 5
# How many of the most recent non-"notice" messages ChatService._build_messages
# replays to the model as conversation history.
MAX_HISTORY_MESSAGES = 40
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
async def _extract_memory_background(
    session_id: int,
    user_text: str,
    assistant_text: str,
) -> None:
    """Run memory extraction for one finished chat turn on its own DB session.

    A fresh ``SessionLocal()`` session is opened for the call to
    ``extract_after_turn`` and is always closed afterwards. Any exception
    raised by the extraction is logged as a warning and swallowed, so a
    failure here never propagates to the caller.

    Args:
        session_id: Id of the chat session the turn belongs to.
        user_text: The user's message for this turn.
        assistant_text: The assistant's final reply for this turn.
    """
    background_db = SessionLocal()
    try:
        await extract_after_turn(
            background_db,
            session_id,
            user_text,
            assistant_text,
        )
    except Exception as error:
        # Best-effort: report the failure but do not re-raise it.
        logger.warning("Background memory extraction failed: %s", error)
    finally:
        background_db.close()
class ChatService:
    """Chat session CRUD plus the streaming LLM / tool-calling loop.

    An instance wraps one SQLAlchemy ``Session`` and creates its own
    ``LLMClient`` and ``CharacterService``. ``stream_response`` is the main
    entry point: it persists the user message, replays history to the model,
    executes requested tools and yields Server-Sent-Event formatted strings.
    """

    # Strong references to fire-and-forget tasks. The event loop only keeps
    # weak references to tasks, so an unreferenced task may be garbage
    # collected before it finishes. Kept on the class so the references
    # outlive any single ChatService instance.
    _background_tasks: set[asyncio.Task[None]] = set()

    def __init__(self, db: Session):
        """Bind the service to ``db`` and create its LLM/character helpers."""
        self.db = db
        self.llm = LLMClient()
        self.character = CharacterService()

    def list_sessions(self) -> list[ChatSession]:
        """Return all chat sessions ordered by ``updated_at`` descending."""
        stmt = select(ChatSession).order_by(ChatSession.updated_at.desc())
        return list(self.db.scalars(stmt).all())

    def get_session(self, session_id: int) -> ChatSession | None:
        """Look up a chat session by primary key; ``None`` when absent."""
        return self.db.get(ChatSession, session_id)

    def create_session(self, title: str = "Новый чат") -> ChatSession:
        """Create, commit and return a new chat session with ``title``."""
        session = ChatSession(title=title)
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def delete_session(self, session_id: int) -> bool:
        """Delete a session; return ``False`` if it does not exist."""
        session = self.get_session(session_id)
        if not session:
            return False
        self.db.delete(session)
        self.db.commit()
        return True

    def _build_system_prompt(self, session_id: int | None = None) -> str:
        """Assemble the system prompt from the character prompt and contexts.

        The sections are joined with blank lines in this order: character
        prompt, date/time, memory, fitness, shopping, weather, pomodoro
        status, projects.
        """
        status = PomodoroService(self.db).get_status()
        memory_snapshot = get_memory_snapshot(self.db, session_id)
        fitness_snapshot = get_fitness_snapshot(self.db)
        shopping_snapshot = get_shopping_snapshot(self.db)
        projects_snapshot = get_projects_snapshot(self.db)
        return (
            f"{self.character.get_system_prompt()}\n\n"
            f"{format_datetime_context(self.db)}\n\n"
            f"{format_memory_context(memory_snapshot)}\n\n"
            f"{format_fitness_context(fitness_snapshot)}\n\n"
            f"{format_shopping_context(shopping_snapshot)}\n\n"
            f"{format_weather_snapshot()}\n\n"
            f"{format_pomodoro_context(status)}\n\n"
            f"{format_projects_context(projects_snapshot)}"
        )

    @staticmethod
    def _drop_orphan_tool_messages(chat_messages: list[Message]) -> list[Message]:
        """Strip leading ``tool`` messages from a truncated history window.

        When the history is cut to the last N messages, the cut can land
        between an assistant ``tool_calls`` message and its tool results.
        OpenAI-compatible chat APIs generally reject a ``tool`` message that
        does not answer a preceding ``tool_calls``, so such orphans are
        removed. Returns a new list; the input is not modified.
        """
        first_kept = next(
            (idx for idx, msg in enumerate(chat_messages) if msg.role != "tool"),
            len(chat_messages),
        )
        return chat_messages[first_kept:]

    def _build_messages(self, session: ChatSession) -> list[dict[str, Any]]:
        """Build the message list sent to the LLM for ``session``.

        The result starts with the system prompt (optionally extended with
        an identity hint based on the latest user message and a note about
        truncated history), followed by the stored non-"notice" messages,
        limited to the most recent ``MAX_HISTORY_MESSAGES``.
        """
        system_prompt = self._build_system_prompt(session.id)
        # "notice" rows are UI-only annotations and are never sent to the model.
        all_chat = [m for m in session.messages if m.role != "notice"]

        last_user = next((m.content for m in reversed(all_chat) if m.role == "user"), "")
        if last_user:
            memory_snapshot = get_memory_snapshot(self.db, session.id)
            identity_hint = format_identity_hint(memory_snapshot, last_user)
            if identity_hint:
                system_prompt += f"\n\n{identity_hint}"

        if len(all_chat) > MAX_HISTORY_MESSAGES:
            chat_messages = self._drop_orphan_tool_messages(
                all_chat[-MAX_HISTORY_MESSAGES:]
            )
            # Report the number of messages actually kept in the window.
            system_prompt += (
                f"\n\n[История чата: в контексте последние {len(chat_messages)} "
                f"из {len(all_chat)} сообщений. Раннее — в сводке сессии, если сохранена.]"
            )
        else:
            chat_messages = all_chat

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": system_prompt}
        ]
        for msg in chat_messages:
            # Empty stored content is sent as None rather than "".
            entry: dict[str, Any] = {"role": msg.role, "content": msg.content or None}
            if msg.tool_calls_json:
                entry["tool_calls"] = json.loads(msg.tool_calls_json)
            reasoning_data = LLMClient.deserialize_reasoning(msg.reasoning_json)
            if reasoning_data:
                LLMClient.attach_reasoning_to_message(
                    entry,
                    reasoning=reasoning_data.get("reasoning", ""),
                    reasoning_details=reasoning_data.get("reasoning_details"),
                )
            if msg.role == "tool" and msg.tool_call_id:
                entry["tool_call_id"] = msg.tool_call_id
            messages.append(entry)
        return messages

    def _save_message(
        self,
        session_id: int,
        role: str,
        content: str = "",
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
        reasoning_json: str | None = None,
    ) -> Message:
        """Persist one message row, commit, and return the refreshed row.

        ``tool_calls`` is stored as JSON text (``None`` when empty). When a
        non-empty user message is saved while the session still carries the
        default title, the title is replaced by the first 60 characters of
        the message, with "..." appended if it was longer.
        """
        message = Message(
            session_id=session_id,
            role=role,
            content=content,
            tool_calls_json=json.dumps(tool_calls, ensure_ascii=False) if tool_calls else None,
            reasoning_json=reasoning_json,
            tool_call_id=tool_call_id,
        )
        self.db.add(message)
        session = self.get_session(session_id)
        if session and role == "user" and session.title == "Новый чат" and content:
            session.title = content[:60] + ("..." if len(content) > 60 else "")
        self.db.commit()
        self.db.refresh(message)
        return message

    async def stream_response(self, session_id: int, user_text: str) -> AsyncIterator[str]:
        """Stream the assistant's reply to ``user_text`` as SSE strings.

        Saves the user message, then runs up to ``MAX_TOOL_ROUNDS`` LLM
        rounds. A round that ends with tool calls executes each tool, stores
        the assistant and tool messages, and starts another round; a round
        without tool calls produces the final reply.

        Yields SSE frames with these event names: ``token`` (content chunk),
        ``notice`` (tool notice), ``pomodoro`` (pomodoro tool result),
        ``done`` (final reply saved) and ``error`` (unknown session, LLM
        error, empty reply, or too many tool rounds).
        """
        session = self.get_session(session_id)
        if not session:
            yield self._sse("error", {"message": "Session not found"})
            return

        self._save_message(session_id, "user", user_text)
        messages = self._build_messages(session)
        # Text streamed during tool rounds; used as a fallback final reply.
        streamed_reply_parts: list[str] = []

        for _ in range(MAX_TOOL_ROUNDS):
            content_parts: list[str] = []
            tool_calls: list[dict[str, Any]] = []
            reasoning = ""
            reasoning_details: list[Any] | None = None

            async for event in self.llm.stream_chat(messages, tools=TOOL_DEFINITIONS):
                if event["type"] == "content":
                    content_parts.append(event["content"])
                    yield self._sse("token", {"content": event["content"]})
                elif event["type"] == "reasoning":
                    # Keep the previous reasoning text if this event has none.
                    reasoning = event.get("reasoning", "") or reasoning
                    if event.get("reasoning_details"):
                        reasoning_details = event["reasoning_details"]
                elif event["type"] == "error":
                    yield self._sse("error", {"message": event.get("content", "LLM error")})
                    return
                elif event["type"] == "tool_calls":
                    tool_calls = event["tool_calls"]

            if tool_calls:
                round_text = "".join(content_parts)
                if round_text.strip():
                    streamed_reply_parts.append(round_text)
                assistant_msg: dict[str, Any] = {
                    "role": "assistant",
                    "content": round_text or None,
                    "tool_calls": tool_calls,
                }
                LLMClient.attach_reasoning_to_message(
                    assistant_msg,
                    reasoning=reasoning,
                    reasoning_details=reasoning_details,
                )
                reasoning_json = LLMClient.serialize_reasoning(
                    reasoning=reasoning,
                    reasoning_details=reasoning_details,
                )
                messages.append(assistant_msg)
                self._save_message(
                    session_id,
                    "assistant",
                    round_text,
                    tool_calls=tool_calls,
                    reasoning_json=reasoning_json,
                )

                for tool_call in tool_calls:
                    fn = tool_call["function"]
                    args = LLMClient.parse_tool_arguments(fn.get("arguments", ""))
                    result = await execute_tool(
                        self.db, fn["name"], args, session_id=session_id
                    )
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call["id"],
                            "content": result,
                        }
                    )
                    self._save_message(session_id, "tool", result, tool_call_id=tool_call["id"])

                    notice = format_tool_notice(fn["name"], result)
                    if notice:
                        self._save_message(session_id, "notice", notice)
                        yield self._sse("notice", {"content": notice})

                    if fn["name"] in POMODORO_TOOL_NAMES:
                        # The tool has already run and been persisted; a
                        # non-JSON result must not abort the whole stream,
                        # so only the UI event is skipped.
                        try:
                            pomodoro_payload = json.loads(result)
                        except (TypeError, ValueError):
                            logger.warning(
                                "Pomodoro tool %s returned a non-JSON result; "
                                "skipping pomodoro event",
                                fn["name"],
                            )
                        else:
                            yield self._sse(
                                "pomodoro",
                                {"name": fn["name"], "result": pomodoro_payload},
                            )
                continue

            # Final round: prefer this round's text, then text streamed in
            # earlier tool rounds, then the reasoning text.
            final_content = "".join(content_parts).strip()
            if not final_content and streamed_reply_parts:
                final_content = "".join(streamed_reply_parts).strip()
            if not final_content and reasoning:
                final_content = reasoning.strip()
            if not final_content:
                yield self._sse(
                    "error",
                    {
                        "message": (
                            "Модель не вернула текст. Для deepseek-v4-pro: "
                            "OPENROUTER_TOOLS_ENABLED=true и OPENROUTER_REASONING_EFFORT=none. "
                            "Для памяти: MEMORY_EXTRACT_MODEL=deepseek/deepseek-chat."
                        ),
                    },
                )
                return

            self._save_message(session_id, "assistant", final_content)
            yield self._sse("done", {})

            if get_settings().memory_auto_extract:
                # Hold a strong reference until the task completes so it is
                # not garbage collected mid-execution.
                task = asyncio.create_task(
                    _extract_memory_background(session_id, user_text, final_content)
                )
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
            return

        yield self._sse("error", {"message": "Too many tool call rounds"})

    @staticmethod
    def _sse(event: str, data: dict[str, Any]) -> str:
        """Format one Server-Sent-Event frame with a JSON ``data`` payload."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"