This commit is contained in:
2026-06-09 09:36:48 +03:00
parent 8247b7116f
commit f0fda693d8
49 changed files with 5503 additions and 1 deletions
+3
View File
@@ -0,0 +1,3 @@
from app.chat.service import ChatService
__all__ = ["ChatService"]
+141
View File
@@ -0,0 +1,141 @@
import json
from collections.abc import AsyncIterator
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import get_settings
from app.db.models import ChatSession, Message
from app.llm.client import LLMClient
from app.tools.registry import TOOL_DEFINITIONS, execute_tool
MAX_TOOL_ROUNDS = 5
class ChatService:
def __init__(self, db: Session):
self.db = db
self.llm = LLMClient()
self.system_prompt = get_settings().load_system_prompt()
def list_sessions(self) -> list[ChatSession]:
stmt = select(ChatSession).order_by(ChatSession.updated_at.desc())
return list(self.db.scalars(stmt).all())
def get_session(self, session_id: int) -> ChatSession | None:
return self.db.get(ChatSession, session_id)
def create_session(self, title: str = "Новый чат") -> ChatSession:
session = ChatSession(title=title)
self.db.add(session)
self.db.commit()
self.db.refresh(session)
return session
def delete_session(self, session_id: int) -> bool:
session = self.get_session(session_id)
if not session:
return False
self.db.delete(session)
self.db.commit()
return True
def _build_messages(self, session: ChatSession) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = [{"role": "system", "content": self.system_prompt}]
for msg in session.messages:
content = msg.content or None
entry: dict[str, Any] = {"role": msg.role, "content": content}
if msg.tool_calls_json:
entry["tool_calls"] = json.loads(msg.tool_calls_json)
if not content:
entry["content"] = None
if msg.role == "tool" and msg.tool_call_id:
entry["tool_call_id"] = msg.tool_call_id
messages.append(entry)
return messages
def _save_message(
self,
session_id: int,
role: str,
content: str = "",
tool_calls: list[dict[str, Any]] | None = None,
tool_call_id: str | None = None,
) -> Message:
message = Message(
session_id=session_id,
role=role,
content=content,
tool_calls_json=json.dumps(tool_calls, ensure_ascii=False) if tool_calls else None,
tool_call_id=tool_call_id,
)
self.db.add(message)
session = self.get_session(session_id)
if session and role == "user" and session.title == "Новый чат" and content:
session.title = content[:60] + ("..." if len(content) > 60 else "")
self.db.commit()
self.db.refresh(message)
return message
async def stream_response(self, session_id: int, user_text: str) -> AsyncIterator[str]:
session = self.get_session(session_id)
if not session:
yield self._sse("error", {"message": "Session not found"})
return
self._save_message(session_id, "user", user_text)
messages = self._build_messages(session)
for _ in range(MAX_TOOL_ROUNDS):
content_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
async for event in self.llm.stream_chat(messages, tools=TOOL_DEFINITIONS):
if event["type"] == "content":
content_parts.append(event["content"])
yield self._sse("token", {"content": event["content"]})
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
if tool_calls:
assistant_msg: dict[str, Any] = {
"role": "assistant",
"content": "".join(content_parts) or None,
"tool_calls": tool_calls,
}
messages.append(assistant_msg)
self._save_message(
session_id,
"assistant",
"".join(content_parts),
tool_calls=tool_calls,
)
for tool_call in tool_calls:
fn = tool_call["function"]
args = LLMClient.parse_tool_arguments(fn.get("arguments", ""))
result = execute_tool(self.db, fn["name"], args)
tool_message = {
"role": "tool",
"tool_call_id": tool_call["id"],
"content": result,
}
messages.append(tool_message)
self._save_message(session_id, "tool", result, tool_call_id=tool_call["id"])
yield self._sse("tool", {"name": fn["name"], "result": json.loads(result)})
continue
final_content = "".join(content_parts)
if final_content:
self._save_message(session_id, "assistant", final_content)
yield self._sse("done", {})
return
yield self._sse("error", {"message": "Too many tool call rounds"})
@staticmethod
def _sse(event: str, data: dict[str, Any]) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"