import asyncio
import json
import logging
import time
from collections.abc import AsyncIterator, Coroutine
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.config import get_settings
from app.db.base import SessionLocal
from app.character.service import CharacterService
from app.chat.history import sanitize_openai_messages, strip_historical_reasoning
from app.chat.notices import (
    POMODORO_TOOL_NAMES,
    format_pomodoro_context,
    format_tool_notice,
)
from app.fitness.context import format_fitness_context, get_fitness_snapshot
from app.homelab.context import format_datetime_context
from app.homelab.openmeteo import format_weather_snapshot
from app.memory.context import (
    format_identity_hint,
    format_memory_context,
    get_memory_snapshot,
)
from app.memory.extract import extract_after_turn
from app.projects.context import format_projects_context, get_projects_snapshot
from app.shopping.context import format_shopping_context, get_shopping_snapshot
from app.db.models import ChatSession, Message
from app.llm.client import LLMClient
from app.pomodoro.service import PomodoroService
from app.tools.registry import TOOL_DEFINITIONS, execute_tool

MAX_TOOL_ROUNDS = 5
MAX_HISTORY_MESSAGES = 40
# Title given to freshly created sessions; _save_message replaces it with the
# first user message.
DEFAULT_SESSION_TITLE = "Новый чат"

logger = logging.getLogger(__name__)

# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes.
_background_tasks: set[asyncio.Task[Any]] = set()


def _spawn_background(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
    """Schedule *coro* on the running loop and keep it referenced until done."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


def _build_messages_for_session(session_id: int) -> list[dict[str, Any]]:
    """Build the LLM message list for *session_id* using a dedicated DB session.

    ``stream_response`` runs this via ``asyncio.to_thread``, so it opens and
    closes its own ``SessionLocal()`` instead of sharing the caller's Session.

    Returns:
        The message list, or ``[]`` if the session does not exist.
    """
    db = SessionLocal()
    try:
        service = ChatService(db)
        session = service.get_session(session_id)
        if not session:
            return []
        return service._build_messages(session)
    finally:
        db.close()


async def _extract_memory_background(
    session_id: int,
    user_text: str,
    assistant_text: str,
) -> None:
    """Run memory extraction for one finished turn on its own DB session.

    Best-effort: any failure is logged (with traceback) and swallowed so it
    never affects the chat response that has already been sent.
    """
    db = SessionLocal()
    try:
        await extract_after_turn(db, session_id, user_text, assistant_text)
    except Exception as exc:
        logger.warning("Background memory extraction failed: %s", exc, exc_info=True)
    finally:
        db.close()


class ChatService:
    """Chat sessions, message persistence and the LLM tool-calling loop.

    All persistence goes through the SQLAlchemy ``Session`` passed in.
    """

    def __init__(self, db: Session):
        self.db = db
        self.llm = LLMClient()
        self.character = CharacterService()

    def list_sessions(self) -> list[ChatSession]:
        """Return all chat sessions, most recently updated first."""
        stmt = select(ChatSession).order_by(ChatSession.updated_at.desc())
        return list(self.db.scalars(stmt).all())

    def get_session(self, session_id: int) -> ChatSession | None:
        """Return the session with primary key *session_id*, or ``None``."""
        return self.db.get(ChatSession, session_id)

    def create_session(self, title: str = DEFAULT_SESSION_TITLE) -> ChatSession:
        """Create, commit and return a new chat session."""
        session = ChatSession(title=title)
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def delete_session(self, session_id: int) -> bool:
        """Delete a session; return ``False`` if it did not exist."""
        session = self.get_session(session_id)
        if not session:
            return False
        self.db.delete(session)
        self.db.commit()
        return True

    def _build_system_prompt(self, session_id: int | None = None) -> str:
        """Compose the system prompt from the character prompt and context snapshots.

        Sections (character, datetime, memory, fitness, shopping, weather,
        pomodoro, projects) are separated by blank lines.
        """
        status = PomodoroService(self.db).get_status()
        memory_snapshot = get_memory_snapshot(self.db, session_id)
        fitness_snapshot = get_fitness_snapshot(self.db)
        shopping_snapshot = get_shopping_snapshot(self.db)
        projects_snapshot = get_projects_snapshot(self.db)
        return (
            f"{self.character.get_system_prompt()}\n\n"
            f"{format_datetime_context(self.db)}\n\n"
            f"{format_memory_context(memory_snapshot)}\n\n"
            f"{format_fitness_context(fitness_snapshot)}\n\n"
            f"{format_shopping_context(shopping_snapshot)}\n\n"
            f"{format_weather_snapshot()}\n\n"
            f"{format_pomodoro_context(status)}\n\n"
            f"{format_projects_context(projects_snapshot)}"
        )

    def _build_messages(self, session: ChatSession) -> list[dict[str, Any]]:
        """Convert a stored session into an OpenAI-style message list.

        The list starts with the system prompt. It is followed by at most
        ``MAX_HISTORY_MESSAGES`` of the latest non-notice messages, and is then
        passed through ``sanitize_openai_messages`` and
        ``strip_historical_reasoning``.
        """
        system_prompt = self._build_system_prompt(session.id)
        # "notice" rows are UI-only records and are never sent to the model.
        all_chat = [m for m in session.messages if m.role != "notice"]

        last_user = next((m.content for m in reversed(all_chat) if m.role == "user"), "")
        if last_user:
            memory_snapshot = get_memory_snapshot(self.db, session.id)
            identity_hint = format_identity_hint(memory_snapshot, last_user)
            if identity_hint:
                system_prompt += f"\n\n{identity_hint}"

        if len(all_chat) > MAX_HISTORY_MESSAGES:
            system_prompt += (
                f"\n\n[История чата: в контексте последние {MAX_HISTORY_MESSAGES} "
                f"из {len(all_chat)} сообщений. Раннее — в сводке сессии, если сохранена.]"
            )

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": system_prompt}
        ]
        # A negative slice longer than the list simply yields the whole list.
        for msg in all_chat[-MAX_HISTORY_MESSAGES:]:
            # Empty content is sent as None (e.g. assistant turns that only call tools).
            entry: dict[str, Any] = {"role": msg.role, "content": msg.content or None}
            if msg.tool_calls_json:
                entry["tool_calls"] = json.loads(msg.tool_calls_json)
            reasoning_data = LLMClient.deserialize_reasoning(msg.reasoning_json)
            if reasoning_data:
                LLMClient.attach_reasoning_to_message(
                    entry,
                    reasoning=reasoning_data.get("reasoning", ""),
                    reasoning_details=reasoning_data.get("reasoning_details"),
                )
            if msg.role == "tool" and msg.tool_call_id:
                entry["tool_call_id"] = msg.tool_call_id
            messages.append(entry)

        messages = sanitize_openai_messages(messages)
        messages = strip_historical_reasoning(messages)
        return messages

    def _save_message(
        self,
        session_id: int,
        role: str,
        content: str = "",
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
        reasoning_json: str | None = None,
    ) -> Message:
        """Persist one message and commit.

        A non-empty user message also renames a session that still has the
        default title to the first 60 characters of the message ("..." appended
        when truncated).
        """
        message = Message(
            session_id=session_id,
            role=role,
            content=content,
            tool_calls_json=json.dumps(tool_calls, ensure_ascii=False) if tool_calls else None,
            reasoning_json=reasoning_json,
            tool_call_id=tool_call_id,
        )
        self.db.add(message)
        # Only a user message can retitle a session, so skip the lookup otherwise.
        if role == "user" and content:
            session = self.get_session(session_id)
            if session and session.title == DEFAULT_SESSION_TITLE:
                session.title = content[:60] + ("..." if len(content) > 60 else "")
        self.db.commit()
        self.db.refresh(message)
        return message

    def save_user_message(self, session_id: int, user_text: str) -> None:
        """Persist a user message (see ``_save_message`` for the title rule)."""
        self._save_message(session_id, "user", user_text)

    async def _run_tool_call(
        self,
        messages: list[dict[str, Any]],
        session_id: int,
        tool_call: dict[str, Any],
    ) -> tuple[str | None, dict[str, Any] | None]:
        """Execute one model-requested tool call and record its result.

        Appends the ``tool`` message to *messages* and saves it. Also saves a
        ``notice`` message when ``format_tool_notice`` returns one.

        Returns:
            ``(notice, pomodoro_event)``; either may be ``None``. The pomodoro
            event is produced only for tools in ``POMODORO_TOOL_NAMES`` whose
            result parses as JSON.
        """
        fn = tool_call["function"]
        name = fn["name"]
        args = LLMClient.parse_tool_arguments(fn.get("arguments", ""))
        result = await execute_tool(self.db, name, args, session_id=session_id)
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": result,
            }
        )
        self._save_message(session_id, "tool", result, tool_call_id=tool_call["id"])

        notice = format_tool_notice(name, result)
        if notice:
            self._save_message(session_id, "notice", notice)

        pomodoro_event: dict[str, Any] | None = None
        if name in POMODORO_TOOL_NAMES:
            try:
                pomodoro_event = {"name": name, "result": json.loads(result)}
            except (TypeError, ValueError):
                # The tool already ran and was persisted; don't abort the turn
                # just because its result cannot be forwarded as JSON.
                logger.warning(
                    "chat session=%s pomodoro tool %s returned non-JSON result",
                    session_id,
                    name,
                )
        return notice or None, pomodoro_event

    async def _fallback_complete(
        self,
        messages: list[dict[str, Any]],
        session_id: int,
    ) -> tuple[str, list[str], list[dict[str, Any]]]:
        """Non-streaming fallback used when the streamed turn produced no reply.

        The method calls ``llm.complete`` first with tools and then without
        them, and uses the first result that has content or tool calls. If the
        model requested tools, they are executed and persisted. The reply is
        then the joined tool notices, or, when there are none, a tool-free
        follow-up completion.

        Returns:
            ``(content, notices, pomodoro_events)``; *content* may be empty.
        """
        logger.info("chat session=%s fallback complete", session_id)
        result: dict[str, Any] = {"content": "", "tool_calls": []}
        for with_tools in (True, False):
            result = await self.llm.complete(
                messages,
                tools=TOOL_DEFINITIONS if with_tools else None,
                temperature=0.5,
                visible_reply=True,
            )
            if (result.get("content") or "").strip() or result.get("tool_calls"):
                break

        tool_calls = result.get("tool_calls") or []
        content = (result.get("content") or "").strip()
        notices: list[str] = []
        pomodoro_events: list[dict[str, Any]] = []

        if not tool_calls:
            return content, notices, pomodoro_events

        assistant_msg: dict[str, Any] = {
            "role": "assistant",
            "content": content or None,
            "tool_calls": tool_calls,
        }
        messages.append(assistant_msg)
        self._save_message(session_id, "assistant", content, tool_calls=tool_calls)

        for tool_call in tool_calls:
            notice, pomodoro_event = await self._run_tool_call(messages, session_id, tool_call)
            if notice:
                notices.append(notice)
            if pomodoro_event is not None:
                pomodoro_events.append(pomodoro_event)

        if notices:
            return "\n\n".join(notices), notices, pomodoro_events

        followup = await self.llm.complete(
            messages,
            tools=None,
            temperature=0.4,
            visible_reply=True,
        )
        return (followup.get("content") or "").strip(), notices, pomodoro_events

    async def stream_response(
        self,
        session_id: int,
        user_text: str,
        *,
        user_message_saved: bool = False,
    ) -> AsyncIterator[str]:
        """Run one chat turn and yield Server-Sent Event strings.

        Events emitted: ``status``, ``token``, ``pomodoro``, ``notice``,
        ``error`` and ``done``. The model may request tools for up to
        ``MAX_TOOL_ROUNDS`` rounds. If the final round is empty, the reply
        falls back in order to:

        1. text streamed in earlier rounds;
        2. the model's reasoning;
        3. tool notices;
        4. a tool-free retry;
        5. ``_fallback_complete``.

        The final reply is persisted. When ``memory_auto_extract`` is enabled,
        memory extraction is scheduled as a background task.
        """
        session = self.get_session(session_id)
        if not session:
            yield self._sse("error", {"message": "Session not found"})
            return
        if not user_message_saved:
            self._save_message(session_id, "user", user_text)

        yield self._sse("status", {"phase": "preparing"})
        t0 = time.monotonic()
        # Prompt assembly does many DB reads; run it off the event loop with its
        # own Session.
        messages = await asyncio.to_thread(_build_messages_for_session, session_id)
        prepare_sec = time.monotonic() - t0
        if not messages:
            yield self._sse("error", {"message": "Session not found"})
            return
        yield self._sse("status", {"phase": "generating"})

        streamed_reply_parts: list[str] = []
        all_tool_notices: list[str] = []
        tools_executed = 0

        for tool_round in range(1, MAX_TOOL_ROUNDS + 1):
            t_round = time.monotonic()
            content_parts: list[str] = []
            tool_calls: list[dict[str, Any]] = []
            reasoning = ""
            reasoning_details: list[Any] | None = None
            finish_reason = ""

            async for event in self.llm.stream_chat(messages, tools=TOOL_DEFINITIONS):
                event_type = event["type"]
                if event_type == "content":
                    content_parts.append(event["content"])
                    yield self._sse("token", {"content": event["content"]})
                elif event_type == "reasoning":
                    # Keep the latest non-empty reasoning snapshot.
                    reasoning = event.get("reasoning", "") or reasoning
                    if event.get("reasoning_details"):
                        reasoning_details = event["reasoning_details"]
                elif event_type == "error":
                    logger.warning(
                        "chat session=%s llm_error round=%d prepare=%.2fs: %s",
                        session_id,
                        tool_round,
                        prepare_sec,
                        event.get("content"),
                    )
                    yield self._sse("error", {"message": event.get("content", "LLM error")})
                    return
                elif event_type == "tool_calls":
                    tool_calls = event["tool_calls"]
                elif event_type == "done":
                    finish_reason = event.get("finish_reason", "")

            logger.info(
                "chat session=%s round=%d prepare=%.2fs llm=%.2fs "
                "content_len=%d tool_calls=%d finish_reason=%s reasoning_len=%d",
                session_id,
                tool_round,
                prepare_sec,
                time.monotonic() - t_round,
                len("".join(content_parts)),
                len(tool_calls),
                finish_reason,
                len(reasoning),
            )

            if tool_calls:
                round_text = "".join(content_parts)
                if round_text.strip():
                    streamed_reply_parts.append(round_text)
                assistant_msg: dict[str, Any] = {
                    "role": "assistant",
                    "content": round_text or None,
                    "tool_calls": tool_calls,
                }
                LLMClient.attach_reasoning_to_message(
                    assistant_msg,
                    reasoning=reasoning,
                    reasoning_details=reasoning_details,
                )
                reasoning_json = LLMClient.serialize_reasoning(
                    reasoning=reasoning,
                    reasoning_details=reasoning_details,
                )
                messages.append(assistant_msg)
                self._save_message(
                    session_id,
                    "assistant",
                    round_text,
                    tool_calls=tool_calls,
                    reasoning_json=reasoning_json,
                )

                round_notices: list[str] = []
                for tool_call in tool_calls:
                    notice, pomodoro_event = await self._run_tool_call(
                        messages, session_id, tool_call
                    )
                    tools_executed += 1
                    if notice:
                        round_notices.append(notice)
                        all_tool_notices.append(notice)
                    if pomodoro_event is not None:
                        yield self._sse("pomodoro", pomodoro_event)
                for notice in round_notices:
                    yield self._sse("notice", {"content": notice})
                continue

            final_content = "".join(content_parts).strip()
            if not final_content and streamed_reply_parts:
                # Text from earlier tool rounds has already been streamed as tokens.
                final_content = "".join(streamed_reply_parts).strip()
            if not final_content and reasoning:
                final_content = reasoning.strip()
                if final_content:
                    # Reasoning is never streamed, so send it to the client
                    # explicitly; otherwise the live view stays blank while the
                    # reply is still persisted.
                    yield self._sse("token", {"content": final_content})
            if not final_content and all_tool_notices:
                final_content = "\n\n".join(all_tool_notices)
                yield self._sse("token", {"content": final_content})

            if not final_content and tools_executed:
                retry = await self.llm.complete(
                    messages,
                    tools=None,
                    temperature=0.4,
                    visible_reply=True,
                )
                final_content = (retry.get("content") or "").strip()
                if final_content:
                    yield self._sse("token", {"content": final_content})

            if not final_content:
                final_content, fb_notices, fb_pomodoro = await self._fallback_complete(
                    messages, session_id
                )
                if final_content:
                    yield self._sse("token", {"content": final_content})
                for notice in fb_notices:
                    yield self._sse("notice", {"content": notice})
                for event in fb_pomodoro:
                    yield self._sse("pomodoro", event)

            if not final_content:
                logger.warning(
                    "chat session=%s empty_reply tools=%d rounds=%d finish_reason=%s",
                    session_id,
                    tools_executed,
                    tool_round,
                    finish_reason,
                )
                yield self._sse(
                    "error",
                    {
                        "message": (
                            "Модель не вернула ответ (finish_reason="
                            f"{finish_reason or 'unknown'}). "
                            "Попробуй новый чат или проверь OPENROUTER_MODEL."
                        ),
                    },
                )
                return

            self._save_message(session_id, "assistant", final_content)
            logger.info(
                "chat session=%s done tools=%d reply_len=%d total=%.2fs",
                session_id,
                tools_executed,
                len(final_content),
                time.monotonic() - t0,
            )
            yield self._sse("done", {})
            if get_settings().memory_auto_extract:
                # Keep a strong reference so the task is not garbage-collected mid-run.
                _spawn_background(
                    _extract_memory_background(session_id, user_text, final_content)
                )
            return

        yield self._sse("error", {"message": "Too many tool call rounds"})

    @staticmethod
    def _sse(event: str, data: dict[str, Any]) -> str:
        """Format one Server-Sent Event with a JSON ``data`` payload."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"