579 lines
24 KiB
Python
579 lines
24 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.config import get_settings
|
|
from app.db.base import SessionLocal
|
|
from app.character.service import CharacterService
|
|
from app.chat.history import sanitize_openai_messages, strip_historical_reasoning
|
|
from app.chat.notice_inbox import DISPLAY_ONLY_ROLES
|
|
from app.chat.notices import (
|
|
POMODORO_TOOL_NAMES,
|
|
format_pomodoro_context,
|
|
format_tool_notice,
|
|
)
|
|
from app.fitness.context import format_fitness_context, get_fitness_snapshot
|
|
from app.homelab.context import format_datetime_context
|
|
from app.homelab.openmeteo import OpenMeteoClient, format_weather_snapshot
|
|
from app.memory.context import (
|
|
format_identity_hint,
|
|
format_memory_context,
|
|
get_memory_snapshot,
|
|
)
|
|
from app.memory.extract import extract_after_turn
|
|
from app.projects.context import format_projects_context, get_projects_snapshot
|
|
from app.reminders_scoped.context import format_reminders_context, get_reminders_snapshot
|
|
from app.shopping.context import format_shopping_context, get_shopping_snapshot
|
|
from app.db.models import ChatSession, Message
|
|
from app.llm.client import LLMClient
|
|
from app.pomodoro.service import PomodoroService
|
|
from app.tools.registry import TOOL_DEFINITIONS, execute_tool
|
|
from app.vision.analyze import format_vision_turn_hint
|
|
|
|
# Upper bound on LLM rounds per reply; each round may request tool calls.
MAX_TOOL_ROUNDS = 5

# Number of most recent chat messages replayed to the model.
MAX_HISTORY_MESSAGES = 40

# Rendered domain context keyed by "<user_id>:<domain>".
# Each value is (monotonic expiry timestamp, rendered text).
_DOMAIN_CACHE: dict[str, tuple[float, str]] = {}

# Lifetime of one cached domain rendering, in seconds.
_DOMAIN_TTL_SEC = 60.0

# Substrings that mark a (lower-cased) user query as relevant to a domain.
_DOMAIN_KEYWORDS: dict[str, tuple[str, ...]] = {
    "fitness": (
        "фитнес",
        "тренир",
        "калори",
        "еда",
        "вода",
        "вес",
        "workout",
        "meal",
        "белок",
        "жир",
    ),
    "shopping": (
        "покуп",
        "магазин",
        "список",
        "shopping",
        "корзин",
    ),
    "reminders": (
        "напомин",
        "календар",
        "событи",
        "дедлайн",
        "встреч",
        "план",
    ),
    "projects": (
        "taiga",
        "gitea",
        "задач",
        "проект",
        "git",
        "issue",
        "коммит",
        "ветк",
    ),
    "weather": (
        "погод",
        "дожд",
        "снег",
        "ветер",
        "температур",
        "градус",
        "мороз",
        "жар",
        "на улице",
        "одеть",
        "зонт",
        "прогноз",
        "завтра",
        "послезавтра",
        "выходн",
        "weather",
        "rain",
        "forecast",
        "umbrella",
        "outside",
    ),
}

logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _build_messages_for_session(session_id: int, user_id: int) -> list[dict[str, Any]]:
    """Build the LLM message list for a chat session on a private DB session.

    A fresh ``SessionLocal`` is opened for the call and always closed, so
    the function does not share a database session with its caller.

    Returns:
        The messages produced by ``ChatService._build_messages``, or an
        empty list when ``ChatService.get_session`` yields no session for
        this user.
    """
    own_db = SessionLocal()
    try:
        chat = ChatService(own_db, user_id)
        chat_session = chat.get_session(session_id)
        return chat._build_messages(chat_session) if chat_session else []
    finally:
        own_db.close()
|
|
|
|
|
|
async def _extract_memory_background(
    session_id: int,
    user_id: int,
    user_text: str,
    assistant_text: str,
) -> None:
    """Run post-turn memory extraction on a private DB session.

    Best effort: any exception raised by ``extract_after_turn`` is logged
    as a warning and swallowed. The DB session is closed in every case.
    """
    scoped_db = SessionLocal()
    try:
        await extract_after_turn(
            scoped_db,
            session_id,
            user_text,
            assistant_text,
            user_id=user_id,
        )
    except Exception as err:
        logger.warning("Background memory extraction failed: %s", err)
    finally:
        scoped_db.close()
|
|
|
|
|
|
class ChatService:
    """Chat sessions, prompt assembly and streamed LLM replies for one user.

    Every query is scoped to ``user_id``; sessions owned by another user are
    treated as missing.
    """

    # Strong references to in-flight background tasks. The event loop keeps
    # only weak references to tasks, so a task nobody else references may be
    # garbage-collected before it finishes.
    _background_tasks: set[asyncio.Task[None]] = set()

    def __init__(self, db: Session, user_id: int):
        """Bind the service to a DB session and the acting user."""
        self.db = db
        self.user_id = user_id
        self.llm = LLMClient()
        self.character = CharacterService(db, user_id)

    def list_sessions(self) -> list[ChatSession]:
        """Return the user's chat sessions, most recently updated first."""
        stmt = (
            select(ChatSession)
            .where(ChatSession.user_id == self.user_id)
            .order_by(ChatSession.updated_at.desc())
        )
        return list(self.db.scalars(stmt).all())

    def get_session(self, session_id: int) -> ChatSession | None:
        """Return the session, or ``None`` if it is missing or owned by another user."""
        session = self.db.get(ChatSession, session_id)
        if session and session.user_id != self.user_id:
            return None
        return session

    def list_messages(
        self,
        session_id: int,
        limit: int = 30,
        before_id: int | None = None,
        after_id: int | None = None,
    ) -> tuple[list[Message], bool]:
        """Return one page of messages in chronological order.

        Args:
            session_id: Session to read from; must belong to the user.
            limit: Maximum number of messages in the page.
            before_id: Return messages older than this message (newest page
                when omitted).
            after_id: Return messages with an id greater than this one; takes
                precedence over ``before_id``.

        Returns:
            ``(messages, has_more)``. ``([], False)`` when the session is not
            accessible or ``before_id`` does not belong to it.
        """
        if not self.get_session(session_id):
            return [], False

        if after_id is not None:
            stmt = (
                select(Message)
                .where(Message.session_id == session_id, Message.id > after_id)
                # id breaks ties between rows sharing a timestamp so the
                # order is deterministic.
                .order_by(Message.created_at.asc(), Message.id.asc())
                .limit(limit + 1)
            )
            rows = list(self.db.scalars(stmt).all())
            return rows[:limit], len(rows) > limit

        stmt = select(Message).where(Message.session_id == session_id)

        if before_id is not None:
            anchor = self.db.get(Message, before_id)
            if anchor is None or anchor.session_id != session_id:
                return [], False
            # A strict "created_at <" alone would skip messages that share
            # the anchor's timestamp; fall back to the id for those rows.
            stmt = stmt.where(
                (Message.created_at < anchor.created_at)
                | (
                    (Message.created_at == anchor.created_at)
                    & (Message.id < anchor.id)
                )
            )

        # One extra row is fetched to learn whether another page exists.
        stmt = stmt.order_by(Message.created_at.desc(), Message.id.desc()).limit(limit + 1)
        rows = list(self.db.scalars(stmt).all())
        has_more = len(rows) > limit
        page = rows[:limit]
        page.reverse()  # query is newest-first; callers get oldest-first
        return page, has_more

    def create_session(self, title: str = "Новый чат") -> ChatSession:
        """Create, commit and return a new chat session for the user."""
        session = ChatSession(user_id=self.user_id, title=title)
        self.db.add(session)
        self.db.commit()
        self.db.refresh(session)
        return session

    def delete_session(self, session_id: int) -> bool:
        """Delete the session; return ``False`` when it is not accessible."""
        session = self.get_session(session_id)
        if not session:
            return False
        self.db.delete(session)
        self.db.commit()
        return True

    def _cached_domain(self, key: str, loader, formatter) -> str:
        """Return ``formatter(loader())``, cached per user and domain.

        A cached rendering is reused until its expiry timestamp
        (``time.monotonic()`` based, ``_DOMAIN_TTL_SEC`` after it was stored).
        """
        cache_key = f"{self.user_id}:{key}"
        now = time.monotonic()
        hit = _DOMAIN_CACHE.get(cache_key)
        if hit and now < hit[0]:
            return hit[1]
        rendered = formatter(loader())
        _DOMAIN_CACHE[cache_key] = (now + _DOMAIN_TTL_SEC, rendered)
        return rendered

    def _domain_relevant(self, key: str, user_query: str) -> bool:
        """Tell whether the query contains any keyword registered for ``key``."""
        query = user_query.strip().lower()
        if not query:
            return False
        keywords = _DOMAIN_KEYWORDS.get(key, ())
        return any(kw in query for kw in keywords)

    def _optional_domain(
        self,
        key: str,
        user_query: str,
        loader,
        formatter,
    ) -> str:
        """Render a domain context only when the query is relevant to it.

        ``loader`` is invoked lazily: not at all for an irrelevant query, and
        not while a cached rendering is still valid.
        """
        if not self._domain_relevant(key, user_query):
            return ""
        return self._cached_domain(key, loader, formatter)

    def _build_system_prompt(
        self,
        session_id: int | None = None,
        user_query: str = "",
        memory_snapshot: dict[str, Any] | None = None,
    ) -> str:
        """Assemble the system prompt from character, memory and domain context.

        Args:
            session_id: Session the memory snapshot is scoped to.
            user_query: Latest user text; drives memory lookup and decides
                which optional domains are included.
            memory_snapshot: Snapshot already fetched by the caller for the
                same session and query. When ``None`` it is fetched here.

        Returns:
            The non-blank parts joined by blank lines.
        """
        status = PomodoroService(self.db, self.user_id).get_status()
        if memory_snapshot is None:
            memory_snapshot = get_memory_snapshot(
                self.db, self.user_id, session_id, query=user_query
            )
        # Domain snapshots are loaded inside the loaders so an irrelevant or
        # still-cached domain costs no lookup at all.
        parts = [
            self.character.get_system_prompt(),
            format_datetime_context(self.db, self.user_id),
            format_memory_context(memory_snapshot),
            self._optional_domain(
                "fitness",
                user_query,
                lambda: get_fitness_snapshot(self.db, self.user_id),
                format_fitness_context,
            ),
            self._optional_domain(
                "shopping",
                user_query,
                lambda: get_shopping_snapshot(self.db, self.user_id),
                format_shopping_context,
            ),
            self._optional_domain(
                "reminders",
                user_query,
                lambda: get_reminders_snapshot(self.db, self.user_id),
                format_reminders_context,
            ),
            self._optional_domain(
                "weather",
                user_query,
                lambda: OpenMeteoClient().fetch_forecast(hours_ahead=6, days_ahead=7),
                lambda snap: format_weather_snapshot(snap, include_daily=True),
            ),
            format_pomodoro_context(status),
            self._optional_domain(
                "projects",
                user_query,
                lambda: get_projects_snapshot(self.db, self.user_id),
                format_projects_context,
            ),
        ]
        return "\n\n".join(part for part in parts if part.strip())

    def _build_messages(self, session: ChatSession) -> list[dict[str, Any]]:
        """Build the message list sent to the LLM for ``session``.

        The list starts with the system prompt, followed by at most
        ``MAX_HISTORY_MESSAGES`` of the latest messages whose role is not in
        ``DISPLAY_ONLY_ROLES``. The result is passed through
        ``sanitize_openai_messages`` and ``strip_historical_reasoning``.
        """
        all_chat = [m for m in session.messages if m.role not in DISPLAY_ONLY_ROLES]
        last_user = next((m.content for m in reversed(all_chat) if m.role == "user"), "")

        # Fetched once and shared by the system prompt and the identity hint.
        memory_snapshot = get_memory_snapshot(
            self.db, self.user_id, session.id, query=last_user
        )
        system_prompt = self._build_system_prompt(
            session.id, user_query=last_user, memory_snapshot=memory_snapshot
        )
        if last_user:
            identity_hint = format_identity_hint(memory_snapshot, last_user)
            if identity_hint:
                system_prompt += f"\n\n{identity_hint}"
            vision_hint = format_vision_turn_hint(last_user)
            if vision_hint:
                system_prompt += f"\n\n{vision_hint}"
        if len(all_chat) > MAX_HISTORY_MESSAGES:
            # Tell the model that the history it sees is truncated.
            system_prompt += (
                f"\n\n[История чата: в контексте последние {MAX_HISTORY_MESSAGES} "
                f"из {len(all_chat)} сообщений. Раннее — в сводке сессии, если сохранена.]"
            )

        messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]

        for msg in all_chat[-MAX_HISTORY_MESSAGES:]:
            # Empty content is sent as None rather than "".
            entry: dict[str, Any] = {"role": msg.role, "content": msg.content or None}
            if msg.tool_calls_json:
                entry["tool_calls"] = json.loads(msg.tool_calls_json)
            reasoning_data = LLMClient.deserialize_reasoning(msg.reasoning_json)
            if reasoning_data:
                LLMClient.attach_reasoning_to_message(
                    entry,
                    reasoning=reasoning_data.get("reasoning", ""),
                    reasoning_details=reasoning_data.get("reasoning_details"),
                )
            if msg.role == "tool" and msg.tool_call_id:
                entry["tool_call_id"] = msg.tool_call_id
            messages.append(entry)

        messages = sanitize_openai_messages(messages)
        return strip_historical_reasoning(messages)

    def _save_message(
        self,
        session_id: int,
        role: str,
        content: str = "",
        tool_calls: list[dict[str, Any]] | None = None,
        tool_call_id: str | None = None,
        reasoning_json: str | None = None,
    ) -> Message:
        """Persist one message, commit, and return the refreshed row.

        The first non-empty user message of a session still carrying the
        default title also becomes its title (first 60 characters, with
        ``...`` appended when truncated).
        """
        message = Message(
            session_id=session_id,
            role=role,
            content=content,
            tool_calls_json=json.dumps(tool_calls, ensure_ascii=False) if tool_calls else None,
            reasoning_json=reasoning_json,
            tool_call_id=tool_call_id,
        )
        self.db.add(message)
        session = self.get_session(session_id)
        if session and role == "user" and session.title == "Новый чат" and content:
            session.title = content[:60] + ("..." if len(content) > 60 else "")
        self.db.commit()
        self.db.refresh(message)
        return message

    def save_user_message(self, session_id: int, user_text: str) -> None:
        """Persist a user message without starting a reply."""
        self._save_message(session_id, "user", user_text)

    async def _run_tool_call(
        self,
        messages: list[dict[str, Any]],
        session_id: int,
        tool_call: dict[str, Any],
    ) -> tuple[str, str, str | None]:
        """Execute one tool call and record its outcome.

        The tool result is appended to ``messages`` (mutated in place) and
        persisted as a ``tool`` message; a non-empty notice is persisted as a
        ``notice`` message.

        Returns:
            ``(tool_name, tool_result, notice)``; ``notice`` is falsy when
            ``format_tool_notice`` produced nothing.
        """
        fn = tool_call["function"]
        args = LLMClient.parse_tool_arguments(fn.get("arguments", ""))
        name = fn["name"]
        result = await execute_tool(
            self.db, name, args, session_id=session_id, user_id=self.user_id
        )
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": result,
            }
        )
        self._save_message(session_id, "tool", result, tool_call_id=tool_call["id"])

        notice = format_tool_notice(name, result)
        if notice:
            self._save_message(session_id, "notice", notice)
        return name, result, notice

    async def _fallback_complete(
        self,
        messages: list[dict[str, Any]],
        session_id: int,
    ) -> tuple[str, list[str], list[dict[str, Any]]]:
        """Non-streaming fallback used when the stream produced nothing.

        Tries a completion with tools, then without. If the model asks for
        tools they are executed and recorded, and a tool-less follow-up
        completion produces the reply. ``messages`` is mutated in place.

        Returns:
            ``(reply_text, notices, pomodoro_events)``.
        """
        logger.info("chat session=%s fallback complete", session_id)
        result: dict[str, Any] = {"content": "", "tool_calls": []}
        for with_tools in (True, False):
            result = await self.llm.complete(
                messages,
                tools=TOOL_DEFINITIONS if with_tools else None,
                temperature=0.5,
                visible_reply=True,
            )
            if (result.get("content") or "").strip() or result.get("tool_calls"):
                break

        tool_calls = result.get("tool_calls") or []
        content = (result.get("content") or "").strip()
        notices: list[str] = []
        pomodoro_events: list[dict[str, Any]] = []

        if not tool_calls:
            return content, notices, pomodoro_events

        messages.append(
            {
                "role": "assistant",
                "content": content or None,
                "tool_calls": tool_calls,
            }
        )
        self._save_message(session_id, "assistant", content, tool_calls=tool_calls)

        for tool_call in tool_calls:
            name, tool_result, notice = await self._run_tool_call(
                messages, session_id, tool_call
            )
            if notice:
                notices.append(notice)
            if name in POMODORO_TOOL_NAMES:
                pomodoro_events.append({"name": name, "result": json.loads(tool_result)})

        followup = await self.llm.complete(
            messages,
            tools=None,
            temperature=0.4,
            visible_reply=True,
        )
        return (followup.get("content") or "").strip(), notices, pomodoro_events

    def context_preview(self, session_id: int, query: str | None = None) -> dict[str, Any]:
        """Describe the system prompt that would be built for a session.

        Args:
            session_id: Session to inspect.
            query: Text to use instead of the session's last user message.

        Returns:
            ``{"ok": False, "error": ...}`` when the session is not
            accessible; otherwise prompt size, memory counters and the first
            4000 characters of the prompt.
        """
        session = self.get_session(session_id)
        if not session:
            return {"ok": False, "error": "Session not found"}
        all_chat = [m for m in session.messages if m.role not in DISPLAY_ONLY_ROLES]
        last_user = query or next(
            (m.content for m in reversed(all_chat) if m.role == "user"), ""
        )
        # Fetched once and shared by the prompt and the counters below.
        memory_snapshot = get_memory_snapshot(
            self.db, self.user_id, session_id, query=last_user
        )
        system_prompt = self._build_system_prompt(
            session_id, user_query=last_user, memory_snapshot=memory_snapshot
        )
        return {
            "ok": True,
            "session_id": session_id,
            "query": last_user,
            "system_prompt_chars": len(system_prompt),
            "memory_facts": len(memory_snapshot.get("facts") or []),
            "memory_total_facts": memory_snapshot.get("total_facts", 0),
            "system_prompt_preview": system_prompt[:4000],
        }

    async def stream_response(
        self,
        session_id: int,
        user_text: str,
        *,
        user_message_saved: bool = False,
    ) -> AsyncIterator[str]:
        """Stream the assistant's reply to ``user_text`` as SSE frames.

        Emitted events: ``status`` (phases ``preparing``, ``generating``,
        ``tools``), ``token``, ``notice``, ``pomodoro``, ``error`` and a final
        ``done``. Up to ``MAX_TOOL_ROUNDS`` LLM rounds are run; a round that
        requests tools executes them and starts another round.

        Args:
            session_id: Session to reply in; must belong to the user.
            user_text: The user's message.
            user_message_saved: Set when the caller already persisted
                ``user_text``; otherwise it is saved here.
        """
        session = self.get_session(session_id)
        if not session:
            yield self._sse("error", {"message": "Session not found"})
            return

        if not user_message_saved:
            self._save_message(session_id, "user", user_text)
        yield self._sse("status", {"phase": "preparing"})

        t0 = time.monotonic()
        # Prompt assembly runs in a worker thread on its own DB session.
        messages = await asyncio.to_thread(
            _build_messages_for_session, session_id, self.user_id
        )
        prepare_sec = time.monotonic() - t0
        if not messages:
            yield self._sse("error", {"message": "Session not found"})
            return
        yield self._sse("status", {"phase": "generating"})

        streamed_reply_parts: list[str] = []
        tools_executed = 0

        for tool_round in range(1, MAX_TOOL_ROUNDS + 1):
            t_round = time.monotonic()
            content_parts: list[str] = []
            tool_calls: list[dict[str, Any]] = []
            reasoning = ""
            reasoning_details: list[Any] | None = None
            finish_reason = ""

            # After a tool round tokens are streamed live; before any tool
            # has run they are buffered (otherwise the text would
            # "overwrite" the notice).
            stream_live = tools_executed > 0

            async for event in self.llm.stream_chat(messages, tools=TOOL_DEFINITIONS):
                if event["type"] == "content":
                    content_parts.append(event["content"])
                    if stream_live:
                        yield self._sse("token", {"content": event["content"]})
                elif event["type"] == "reasoning":
                    reasoning = event.get("reasoning", "") or reasoning
                    if event.get("reasoning_details"):
                        reasoning_details = event["reasoning_details"]
                elif event["type"] == "error":
                    logger.warning(
                        "chat session=%s llm_error round=%d prepare=%.2fs: %s",
                        session_id,
                        tool_round,
                        prepare_sec,
                        event.get("content"),
                    )
                    yield self._sse("error", {"message": event.get("content", "LLM error")})
                    return
                elif event["type"] == "tool_calls":
                    tool_calls = event["tool_calls"]
                elif event["type"] == "done":
                    finish_reason = event.get("finish_reason", "")

            round_text = "".join(content_parts)
            logger.info(
                "chat session=%s round=%d prepare=%.2fs llm=%.2fs "
                "content_len=%d tool_calls=%d finish_reason=%s reasoning_len=%d",
                session_id,
                tool_round,
                prepare_sec,
                time.monotonic() - t_round,
                len(round_text),
                len(tool_calls),
                finish_reason,
                len(reasoning),
            )

            if tool_calls:
                if round_text.strip():
                    streamed_reply_parts.append(round_text)

                assistant_msg: dict[str, Any] = {
                    "role": "assistant",
                    "content": round_text or None,
                    "tool_calls": tool_calls,
                }
                LLMClient.attach_reasoning_to_message(
                    assistant_msg,
                    reasoning=reasoning,
                    reasoning_details=reasoning_details,
                )
                reasoning_json = LLMClient.serialize_reasoning(
                    reasoning=reasoning,
                    reasoning_details=reasoning_details,
                )
                messages.append(assistant_msg)
                self._save_message(
                    session_id,
                    "assistant",
                    round_text,
                    tool_calls=tool_calls,
                    reasoning_json=reasoning_json,
                )

                round_notices: list[str] = []
                for tool_call in tool_calls:
                    name, result, notice = await self._run_tool_call(
                        messages, session_id, tool_call
                    )
                    tools_executed += 1
                    if notice:
                        round_notices.append(notice)
                    if name in POMODORO_TOOL_NAMES:
                        yield self._sse(
                            "pomodoro",
                            {"name": name, "result": json.loads(result)},
                        )

                yield self._sse("status", {"phase": "tools"})
                for notice in round_notices:
                    yield self._sse("notice", {"content": notice})
                continue

            # No tool calls: this round is the final answer. Flush the
            # buffer if its tokens were not streamed live.
            if not stream_live:
                for part in content_parts:
                    yield self._sse("token", {"content": part})

            final_content = round_text.strip()
            # NOTE(review): this branch looks unreachable as written, since
            # ``streamed_reply_parts`` is only filled in rounds that also
            # execute at least one tool. Confirm the intended condition.
            if not final_content and streamed_reply_parts and tools_executed == 0:
                final_content = "".join(streamed_reply_parts).strip()
            if not final_content and reasoning:
                final_content = reasoning.strip()
                if final_content:
                    # This text becomes the saved reply, so the client must
                    # receive it as well.
                    yield self._sse("token", {"content": final_content})
            if not final_content and tools_executed:
                retry = await self.llm.complete(
                    messages,
                    tools=None,
                    temperature=0.4,
                    visible_reply=True,
                )
                final_content = (retry.get("content") or "").strip()
                if final_content:
                    yield self._sse("token", {"content": final_content})
            # Notices are already in the chat as role=notice, so they are
            # not duplicated in the assistant message.
            if not final_content:
                final_content, fb_notices, fb_pomodoro = await self._fallback_complete(
                    messages, session_id
                )
                if final_content:
                    yield self._sse("token", {"content": final_content})
                for notice in fb_notices:
                    yield self._sse("notice", {"content": notice})
                for event in fb_pomodoro:
                    yield self._sse("pomodoro", event)

            if not final_content:
                logger.warning(
                    "chat session=%s empty_reply tools=%d rounds=%d finish_reason=%s",
                    session_id,
                    tools_executed,
                    tool_round,
                    finish_reason,
                )
                yield self._sse(
                    "error",
                    {
                        "message": (
                            "Модель не вернула ответ (finish_reason="
                            f"{finish_reason or 'unknown'}). "
                            "Попробуй новый чат или проверь OPENROUTER_MODEL."
                        ),
                    },
                )
                return

            self._save_message(session_id, "assistant", final_content)

            logger.info(
                "chat session=%s done tools=%d reply_len=%d total=%.2fs",
                session_id,
                tools_executed,
                len(final_content),
                time.monotonic() - t0,
            )
            # Scheduled before the final frame so extraction still happens
            # when the consumer stops iterating at "done".
            if get_settings().memory_auto_extract:
                task = asyncio.create_task(
                    _extract_memory_background(
                        session_id, self.user_id, user_text, final_content
                    )
                )
                # Hold a strong reference until the task completes.
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
            yield self._sse("done", {})
            return

        yield self._sse("error", {"message": "Too many tool call rounds"})

    @staticmethod
    def _sse(event: str, data: dict[str, Any]) -> str:
        """Format one Server-Sent Events frame with a JSON ``data`` payload."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|