"""Chat orchestration: sessions, message history, system-prompt assembly and the
streaming LLM/tool-call loop whose progress is emitted as server-sent events."""

import asyncio
import json
import logging
import time
from collections.abc import AsyncIterator, Callable, Coroutine
from typing import Any

from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session

from app.config import get_settings
from app.db.base import SessionLocal
from app.character.service import CharacterService
from app.chat.history import sanitize_openai_messages, strip_historical_reasoning
from app.chat.notice_inbox import DISPLAY_ONLY_ROLES
from app.chat.notices import (
    POMODORO_TOOL_NAMES,
    format_pomodoro_context,
    format_tool_notice,
)
from app.fitness.context import format_fitness_context, get_fitness_snapshot
from app.homelab.context import format_datetime_context
from app.homelab.openmeteo import format_weather_snapshot
from app.memory.context import (
    format_identity_hint,
    format_memory_context,
    get_memory_snapshot,
)
from app.memory.extract import extract_after_turn
from app.projects.context import format_projects_context, get_projects_snapshot
from app.reminders_scoped.context import format_reminders_context, get_reminders_snapshot
from app.shopping.context import format_shopping_context, get_shopping_snapshot
from app.db.models import ChatSession, Message
from app.llm.client import LLMClient
from app.pomodoro.service import PomodoroService
from app.tools.registry import TOOL_DEFINITIONS, execute_tool

# Maximum LLM rounds per turn; if every round ends in tool calls the turn
# fails with "Too many tool call rounds".
MAX_TOOL_ROUNDS = 5
# Only the most recent N (non display-only) messages are sent to the LLM.
MAX_HISTORY_MESSAGES = 40

# Rendered per-user domain context: "<user_id>:<domain>" -> (expires_at, text),
# where expires_at is on the time.monotonic() clock.
_DOMAIN_CACHE: dict[str, tuple[float, str]] = {}
_DOMAIN_TTL_SEC = 60.0
# Lower-case substrings; a domain's context is added to the system prompt only
# when the user's query contains one of its keywords.
_DOMAIN_KEYWORDS: dict[str, tuple[str, ...]] = {
    "fitness": ("фитнес", "тренир", "калори", "еда", "вода", "вес", "workout", "meal", "белок", "жир"),
    "shopping": ("покуп", "магазин", "список", "shopping", "корзин"),
    "reminders": ("напомин", "календар", "событи", "дедлайн", "встреч", "план"),
    "projects": ("taiga", "gitea", "задач", "проект", "git", "issue", "коммит", "ветк"),
}

# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected mid-run.
_BACKGROUND_TASKS: set[asyncio.Task[None]] = set()

logger = logging.getLogger(__name__)


def _invalidate_domain_cache(user_id: int) -> None:
    """Drop every cached domain context rendered for *user_id*.

    Called after tool execution so the next system prompt does not show data
    from before the tool ran.
    """
    prefix = f"{user_id}:"
    for cache_key in [k for k in _DOMAIN_CACHE if k.startswith(prefix)]:
        _DOMAIN_CACHE.pop(cache_key, None)


def _parse_tool_payload(raw: str) -> Any:
    """Decode a tool result as JSON, falling back to the raw value if it is not JSON."""
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        return raw


def _spawn_background(coro: Coroutine[Any, Any, None]) -> None:
    """Schedule *coro* on the running loop and keep it referenced until it finishes."""
    task = asyncio.create_task(coro)
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)


def _build_messages_for_session(session_id: int, user_id: int) -> list[dict[str, Any]]:
    """Build the LLM message list for a session on a fresh DB session.

    Opens and closes its own ``SessionLocal`` so it can run in a worker thread
    (``stream_response`` calls it via ``asyncio.to_thread``). Returns ``[]`` when
    the session does not exist or belongs to another user.
    """
    db = SessionLocal()
    try:
        service = ChatService(db, user_id)
        session = service.get_session(session_id)
        if not session:
            return []
        return service._build_messages(session)
    finally:
        db.close()


async def _extract_memory_background(
    session_id: int,
    user_id: int,
    user_text: str,
    assistant_text: str,
) -> None:
    """Run memory extraction for a finished turn on its own DB session.

    Failures are logged and swallowed: this runs detached from the request.
    """
    db = SessionLocal()
    try:
        await extract_after_turn(db, session_id, user_text, assistant_text, user_id=user_id)
    except Exception as exc:
        logger.warning("Background memory extraction failed: %s", exc, exc_info=True)
    finally:
        db.close()


class ChatService:
    """Per-user chat operations bound to one SQLAlchemy session."""

    def __init__(self, db: Session, user_id: int):
        self.db = db
        self.user_id = user_id
        self.llm = LLMClient()
        self.character = CharacterService(db, user_id)

    def list_sessions(self) -> list[ChatSession]:
        """Return this user's chat sessions, most recently updated first."""
        stmt = (
            select(ChatSession)
            .where(ChatSession.user_id == self.user_id)
            .order_by(ChatSession.updated_at.desc())
        )
        return list(self.db.scalars(stmt).all())

    def get_session(self, session_id: int) -> ChatSession | None:
        """Return the session if it exists and belongs to this user, else ``None``."""
        session = self.db.get(ChatSession, session_id)
        if session and session.user_id != self.user_id:
            return None
        return session

    def list_messages(
        self,
        session_id: int,
        limit: int = 30,
        before_id: int | None = None,
        after_id: int | None = None,
    ) -> tuple[list[Message], bool]:
        """Return one page of messages (oldest first) and whether more exist.

        - ``after_id``: messages with ``id > after_id``, oldest first.
        - ``before_id``: messages strictly older than that anchor message.
        - neither: the newest ``limit`` messages.

        Ordering uses ``(created_at, id)`` so that messages sharing a timestamp
        are neither skipped nor duplicated across pages. Returns ``([], False)``
        for an unknown/foreign session or an anchor from another session.
        """
        if not self.get_session(session_id):
            return [], False

        if after_id is not None:
            stmt = (
                select(Message)
                .where(Message.session_id == session_id, Message.id > after_id)
                .order_by(Message.created_at.asc(), Message.id.asc())
                .limit(limit + 1)
            )
            rows = list(self.db.scalars(stmt).all())
            return rows[:limit], len(rows) > limit

        stmt = select(Message).where(Message.session_id == session_id)
        if before_id is not None:
            anchor = self.db.get(Message, before_id)
            if anchor is None or anchor.session_id != session_id:
                return [], False
            # Tie-break on id: rows written back-to-back can share created_at.
            stmt = stmt.where(
                or_(
                    Message.created_at < anchor.created_at,
                    and_(Message.created_at == anchor.created_at, Message.id < anchor.id),
                )
            )
        stmt = stmt.order_by(Message.created_at.desc(), Message.id.desc()).limit(limit + 1)
        rows = list(self.db.scalars(stmt).all())
        has_more = len(rows) > limit
        page = rows[:limit]
        page.reverse()  # fetched newest-first; callers get chronological order
        return page, has_more

    def create_session(self, title: str = "Новый чат") -> ChatSession:
        """Create and persist a new chat session for this user."""
        session = ChatSession(user_id=self.user_id, title=title)
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def delete_session(self, session_id: int) -> bool:
        """Delete the session; return ``False`` if it is missing or not this user's."""
        session = self.get_session(session_id)
        if not session:
            return False
        self.db.delete(session)
        self.db.commit()
        return True

    def _cached_domain(
        self,
        key: str,
        loader: Callable[[], Any],
        formatter: Callable[[Any], str],
    ) -> str:
        """Return ``formatter(loader())`` memoized per user/domain for ``_DOMAIN_TTL_SEC``.

        ``loader`` is only invoked on a cache miss.
        """
        cache_key = f"{self.user_id}:{key}"
        now = time.monotonic()
        hit = _DOMAIN_CACHE.get(cache_key)
        if hit and now < hit[0]:
            return hit[1]
        rendered = formatter(loader())
        _DOMAIN_CACHE[cache_key] = (now + _DOMAIN_TTL_SEC, rendered)
        return rendered

    def _domain_relevant(self, key: str, user_query: str) -> bool:
        """True if the lower-cased query contains any keyword registered for *key*."""
        query = (user_query or "").strip().lower()
        if not query:
            return False
        keywords = _DOMAIN_KEYWORDS.get(key, ())
        return any(kw in query for kw in keywords)

    def _optional_domain(
        self,
        key: str,
        user_query: str,
        loader: Callable[[], Any],
        formatter: Callable[[Any], str],
    ) -> str:
        """Rendered domain context if the query is relevant to it, else ``""``."""
        if not self._domain_relevant(key, user_query):
            return ""
        return self._cached_domain(key, loader, formatter)

    def _build_system_prompt(
        self,
        session_id: int | None = None,
        user_query: str = "",
        memory_snapshot: dict[str, Any] | None = None,
    ) -> str:
        """Assemble the system prompt from character, time, memory and domain context.

        Args:
            session_id: Passed to ``get_memory_snapshot``.
            user_query: Last user text; selects memory facts and optional domains.
            memory_snapshot: Precomputed ``get_memory_snapshot`` result to reuse;
                fetched here when ``None``.
        """
        status = PomodoroService(self.db, self.user_id).get_status()
        if memory_snapshot is None:
            memory_snapshot = get_memory_snapshot(
                self.db, self.user_id, session_id, query=user_query
            )
        # Domain snapshots are loaded lazily: only for relevant domains, and only
        # on a cache miss.
        parts = [
            self.character.get_system_prompt(),
            format_datetime_context(self.db, self.user_id),
            format_memory_context(memory_snapshot),
            self._optional_domain(
                "fitness",
                user_query,
                lambda: get_fitness_snapshot(self.db, self.user_id),
                format_fitness_context,
            ),
            self._optional_domain(
                "shopping",
                user_query,
                lambda: get_shopping_snapshot(self.db, self.user_id),
                format_shopping_context,
            ),
            self._optional_domain(
                "reminders",
                user_query,
                lambda: get_reminders_snapshot(self.db, self.user_id),
                format_reminders_context,
            ),
            format_weather_snapshot(),
            format_pomodoro_context(status),
            self._optional_domain(
                "projects",
                user_query,
                lambda: get_projects_snapshot(self.db, self.user_id),
                format_projects_context,
            ),
        ]
        return "\n\n".join(part for part in parts if part.strip())

    def _build_messages(self, session: ChatSession) -> list[dict[str, Any]]:
        """Build the OpenAI-style message list (system prompt + recent history)."""
        all_chat = [m for m in session.messages if m.role not in DISPLAY_ONLY_ROLES]
        last_user = next((m.content for m in reversed(all_chat) if m.role == "user"), "") or ""

        # One memory lookup serves both the prompt and the identity hint.
        memory_snapshot = get_memory_snapshot(self.db, self.user_id, session.id, query=last_user)
        system_prompt = self._build_system_prompt(
            session.id, user_query=last_user, memory_snapshot=memory_snapshot
        )
        if last_user:
            identity_hint = format_identity_hint(memory_snapshot, last_user)
            if identity_hint:
                system_prompt += f"\n\n{identity_hint}"
        if len(all_chat) > MAX_HISTORY_MESSAGES:
            system_prompt += (
                f"\n\n[История чата: в контексте последние {MAX_HISTORY_MESSAGES} "
                f"из {len(all_chat)} сообщений. Раннее — в сводке сессии, если сохранена.]"
            )

        messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
        # Slicing a shorter list returns all of it, so no length check is needed.
        for msg in all_chat[-MAX_HISTORY_MESSAGES:]:
            entry: dict[str, Any] = {"role": msg.role, "content": msg.content or None}
            if msg.tool_calls_json:
                entry["tool_calls"] = json.loads(msg.tool_calls_json)
            reasoning_data = LLMClient.deserialize_reasoning(msg.reasoning_json)
            if reasoning_data:
                LLMClient.attach_reasoning_to_message(
                    entry,
                    reasoning=reasoning_data.get("reasoning", ""),
                    reasoning_details=reasoning_data.get("reasoning_details"),
                )
            if msg.role == "tool" and msg.tool_call_id:
                entry["tool_call_id"] = msg.tool_call_id
            messages.append(entry)

        messages = sanitize_openai_messages(messages)
        return strip_historical_reasoning(messages)

    def _save_message(
        self,
        session_id: int,
        role: str,
        content: str = "",
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
        reasoning_json: str | None = None,
    ) -> Message:
        """Persist one message and commit.

        The first non-empty user message also renames a session still titled
        "Новый чат" (truncated to 60 characters plus "...").
        """
        message = Message(
            session_id=session_id,
            role=role,
            content=content,
            tool_calls_json=json.dumps(tool_calls, ensure_ascii=False) if tool_calls else None,
            reasoning_json=reasoning_json,
            tool_call_id=tool_call_id,
        )
        self.db.add(message)
        if role == "user" and content:
            session = self.get_session(session_id)
            if session and session.title == "Новый чат":
                session.title = content[:60] + ("..." if len(content) > 60 else "")
        self.db.commit()
        self.db.refresh(message)
        return message

    def save_user_message(self, session_id: int, user_text: str) -> None:
        """Persist a user message without starting a response."""
        self._save_message(session_id, "user", user_text)

    async def _run_tool_call(
        self,
        messages: list[dict[str, Any]],
        session_id: int,
        tool_call: dict[str, Any],
    ) -> tuple[str, str, str | None]:
        """Execute one tool call and record it in *messages* and the DB.

        Appends the ``tool`` message, saves it, saves a ``notice`` message when
        ``format_tool_notice`` produces one, and invalidates this user's domain
        cache, since the tool may have changed that data.

        Returns:
            ``(tool name, raw tool result, notice)``; ``notice`` is falsy when
            no notice was produced.
        """
        fn = tool_call["function"]
        args = LLMClient.parse_tool_arguments(fn.get("arguments", ""))
        result = await execute_tool(
            self.db, fn["name"], args, session_id=session_id, user_id=self.user_id
        )
        _invalidate_domain_cache(self.user_id)
        messages.append({"role": "tool", "tool_call_id": tool_call["id"], "content": result})
        self._save_message(session_id, "tool", result, tool_call_id=tool_call["id"])
        notice = format_tool_notice(fn["name"], result)
        if notice:
            self._save_message(session_id, "notice", notice)
        return fn["name"], result, notice

    async def _fallback_complete(
        self,
        messages: list[dict[str, Any]],
        session_id: int,
    ) -> tuple[str, list[str], list[dict[str, Any]]]:
        """Non-streaming fallback path used when the stream produced no reply.

        Tries a completion with tools, then without. If tool calls come back,
        they are executed once and a tool-less follow-up completion supplies
        the reply.

        Returns:
            ``(reply text, notices, pomodoro events)``; the reply may be empty.
        """
        logger.info("chat session=%s fallback complete", session_id)
        result: dict[str, Any] = {"content": "", "tool_calls": []}
        for with_tools in (True, False):
            result = await self.llm.complete(
                messages,
                tools=TOOL_DEFINITIONS if with_tools else None,
                temperature=0.5,
                visible_reply=True,
            )
            if (result.get("content") or "").strip() or result.get("tool_calls"):
                break

        tool_calls = result.get("tool_calls") or []
        content = (result.get("content") or "").strip()
        notices: list[str] = []
        pomodoro_events: list[dict[str, Any]] = []
        if not tool_calls:
            return content, notices, pomodoro_events

        messages.append({"role": "assistant", "content": content or None, "tool_calls": tool_calls})
        self._save_message(session_id, "assistant", content, tool_calls=tool_calls)
        for tool_call in tool_calls:
            name, tool_result, notice = await self._run_tool_call(messages, session_id, tool_call)
            if notice:
                notices.append(notice)
            if name in POMODORO_TOOL_NAMES:
                pomodoro_events.append({"name": name, "result": _parse_tool_payload(tool_result)})

        followup = await self.llm.complete(
            messages,
            tools=None,
            temperature=0.4,
            visible_reply=True,
        )
        return (followup.get("content") or "").strip(), notices, pomodoro_events

    def context_preview(self, session_id: int, query: str | None = None) -> dict[str, Any]:
        """Debug view of the system prompt and memory selection for a session.

        Uses *query* if given, otherwise the last user message.
        """
        session = self.get_session(session_id)
        if not session:
            return {"ok": False, "error": "Session not found"}
        all_chat = [m for m in session.messages if m.role not in DISPLAY_ONLY_ROLES]
        last_user = query or next((m.content for m in reversed(all_chat) if m.role == "user"), "") or ""
        memory_snapshot = get_memory_snapshot(self.db, self.user_id, session_id, query=last_user)
        system_prompt = self._build_system_prompt(
            session_id, user_query=last_user, memory_snapshot=memory_snapshot
        )
        return {
            "ok": True,
            "session_id": session_id,
            "query": last_user,
            "system_prompt_chars": len(system_prompt),
            "memory_facts": len(memory_snapshot.get("facts") or []),
            "memory_total_facts": memory_snapshot.get("total_facts", 0),
            "system_prompt_preview": system_prompt[:4000],
        }

    async def stream_response(
        self,
        session_id: int,
        user_text: str,
        *,
        user_message_saved: bool = False,
    ) -> AsyncIterator[str]:
        """Run one chat turn and yield its progress as SSE frames.

        Saves the user message (unless ``user_message_saved``), builds the prompt
        in a worker thread, then alternates LLM streaming and tool execution for
        up to ``MAX_TOOL_ROUNDS`` rounds. Emitted events: ``status``, ``token``,
        ``notice``, ``pomodoro``, ``error`` and a final ``done``. If the streamed
        reply is empty it falls back, in order, to the reasoning text, a
        tool-less retry (only after tools ran), and ``_fallback_complete``.
        """
        session = self.get_session(session_id)
        if not session:
            yield self._sse("error", {"message": "Session not found"})
            return
        if not user_message_saved:
            self._save_message(session_id, "user", user_text)

        yield self._sse("status", {"phase": "preparing"})
        t0 = time.monotonic()
        messages = await asyncio.to_thread(_build_messages_for_session, session_id, self.user_id)
        prepare_sec = time.monotonic() - t0
        if not messages:
            yield self._sse("error", {"message": "Session not found"})
            return

        yield self._sse("status", {"phase": "generating"})
        tools_executed = 0
        for tool_round in range(1, MAX_TOOL_ROUNDS + 1):
            t_round = time.monotonic()
            content_parts: list[str] = []
            tool_calls: list[dict[str, Any]] = []
            reasoning = ""
            reasoning_details: list[Any] | None = None
            finish_reason = ""
            # After a tool round, stream tokens live; before any tools, buffer
            # them (otherwise the text "overwrites" the notice).
            stream_live = tools_executed > 0
            async for event in self.llm.stream_chat(messages, tools=TOOL_DEFINITIONS):
                kind = event["type"]
                if kind == "content":
                    content_parts.append(event["content"])
                    if stream_live:
                        yield self._sse("token", {"content": event["content"]})
                elif kind == "reasoning":
                    reasoning = event.get("reasoning", "") or reasoning
                    if event.get("reasoning_details"):
                        reasoning_details = event["reasoning_details"]
                elif kind == "error":
                    logger.warning(
                        "chat session=%s llm_error round=%d prepare=%.2fs: %s",
                        session_id,
                        tool_round,
                        prepare_sec,
                        event.get("content"),
                    )
                    yield self._sse("error", {"message": event.get("content", "LLM error")})
                    return
                elif kind == "tool_calls":
                    tool_calls = event["tool_calls"]
                elif kind == "done":
                    finish_reason = event.get("finish_reason", "")

            logger.info(
                "chat session=%s round=%d prepare=%.2fs llm=%.2fs "
                "content_len=%d tool_calls=%d finish_reason=%s reasoning_len=%d",
                session_id,
                tool_round,
                prepare_sec,
                time.monotonic() - t_round,
                len("".join(content_parts)),
                len(tool_calls),
                finish_reason,
                len(reasoning),
            )

            if tool_calls:
                round_text = "".join(content_parts)
                assistant_msg: dict[str, Any] = {
                    "role": "assistant",
                    "content": round_text or None,
                    "tool_calls": tool_calls,
                }
                LLMClient.attach_reasoning_to_message(
                    assistant_msg,
                    reasoning=reasoning,
                    reasoning_details=reasoning_details,
                )
                reasoning_json = LLMClient.serialize_reasoning(
                    reasoning=reasoning,
                    reasoning_details=reasoning_details,
                )
                messages.append(assistant_msg)
                self._save_message(
                    session_id,
                    "assistant",
                    round_text,
                    tool_calls=tool_calls,
                    reasoning_json=reasoning_json,
                )
                round_notices: list[str] = []
                for tool_call in tool_calls:
                    name, result, notice = await self._run_tool_call(messages, session_id, tool_call)
                    tools_executed += 1
                    if notice:
                        round_notices.append(notice)
                    if name in POMODORO_TOOL_NAMES:
                        yield self._sse(
                            "pomodoro",
                            {"name": name, "result": _parse_tool_payload(result)},
                        )
                yield self._sse("status", {"phase": "tools"})
                for notice in round_notices:
                    yield self._sse("notice", {"content": notice})
                continue

            # Final round: flush buffered tokens, then pick the reply text.
            if content_parts and not stream_live:
                for part in content_parts:
                    yield self._sse("token", {"content": part})
            final_content = "".join(content_parts).strip()
            if not final_content and reasoning:
                # The reasoning text becomes the reply; it was never streamed as
                # tokens, so emit it now.
                final_content = reasoning.strip()
                if final_content:
                    yield self._sse("token", {"content": final_content})
            if not final_content and tools_executed:
                retry = await self.llm.complete(
                    messages,
                    tools=None,
                    temperature=0.4,
                    visible_reply=True,
                )
                final_content = (retry.get("content") or "").strip()
                if final_content:
                    yield self._sse("token", {"content": final_content})
            # Notices are already in the chat as role=notice; don't duplicate
            # them into the assistant reply.
            if not final_content:
                final_content, fb_notices, fb_pomodoro = await self._fallback_complete(
                    messages, session_id
                )
                if final_content:
                    yield self._sse("token", {"content": final_content})
                for notice in fb_notices:
                    yield self._sse("notice", {"content": notice})
                for pomodoro_event in fb_pomodoro:
                    yield self._sse("pomodoro", pomodoro_event)
            if not final_content:
                logger.warning(
                    "chat session=%s empty_reply tools=%d rounds=%d finish_reason=%s",
                    session_id,
                    tools_executed,
                    tool_round,
                    finish_reason,
                )
                yield self._sse(
                    "error",
                    {
                        "message": (
                            "Модель не вернула ответ (finish_reason="
                            f"{finish_reason or 'unknown'}). "
                            "Попробуй новый чат или проверь OPENROUTER_MODEL."
                        ),
                    },
                )
                return

            self._save_message(session_id, "assistant", final_content)
            logger.info(
                "chat session=%s done tools=%d reply_len=%d total=%.2fs",
                session_id,
                tools_executed,
                len(final_content),
                time.monotonic() - t0,
            )
            yield self._sse("done", {})
            if get_settings().memory_auto_extract:
                _spawn_background(
                    _extract_memory_background(session_id, self.user_id, user_text, final_content)
                )
            return

        yield self._sse("error", {"message": "Too many tool call rounds"})

    @staticmethod
    def _sse(event: str, data: dict[str, Any]) -> str:
        """Format one server-sent event frame with a JSON payload."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"