From 89158930ee9b7671fa8cb08197a73e942bcc9d0b Mon Sep 17 00:00:00 2001 From: grigo Date: Wed, 10 Jun 2026 14:37:27 +0300 Subject: [PATCH] fixed reasoning --- backend/app/api/schemas.py | 1 + backend/app/chat/service.py | 53 +++++++++++++++++++++++------------ backend/app/llm/client.py | 12 +++++--- backend/app/memory/extract.py | 1 + frontend/src/api/client.ts | 1 + frontend/src/pages/Chat.tsx | 1 + 6 files changed, 47 insertions(+), 22 deletions(-) diff --git a/backend/app/api/schemas.py b/backend/app/api/schemas.py index 798a205..fb2db66 100644 --- a/backend/app/api/schemas.py +++ b/backend/app/api/schemas.py @@ -20,6 +20,7 @@ class MessageOut(BaseModel): id: int role: str content: str + tool_calls_json: str | None = None created_at: datetime model_config = {"from_attributes": True} diff --git a/backend/app/chat/service.py b/backend/app/chat/service.py index 6ea5f3b..4cea54b 100644 --- a/backend/app/chat/service.py +++ b/backend/app/chat/service.py @@ -1,4 +1,6 @@ +import asyncio import json +import logging from collections.abc import AsyncIterator from typing import Any @@ -6,6 +8,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from app.config import get_settings +from app.db.base import SessionLocal from app.character.service import CharacterService from app.chat.notices import ( POMODORO_TOOL_NAMES, @@ -31,6 +34,22 @@ from app.tools.registry import TOOL_DEFINITIONS, execute_tool MAX_TOOL_ROUNDS = 5 MAX_HISTORY_MESSAGES = 40 +logger = logging.getLogger(__name__) + + +async def _extract_memory_background( + session_id: int, + user_text: str, + assistant_text: str, +) -> None: + db = SessionLocal() + try: + await extract_after_turn(db, session_id, user_text, assistant_text) + except Exception as exc: + logger.warning("Background memory extraction failed: %s", exc) + finally: + db.close() + class ChatService: def __init__(self, db: Session): @@ -148,6 +167,7 @@ class ChatService: self._save_message(session_id, "user", user_text) messages = self._build_messages(session) + streamed_reply_parts: list[str] = [] for _ in range(MAX_TOOL_ROUNDS): content_parts: list[str] = [] @@ -170,9 +190,13 @@ class ChatService: tool_calls = event["tool_calls"] if tool_calls: + round_text = "".join(content_parts) + if round_text.strip(): + streamed_reply_parts.append(round_text) + assistant_msg: dict[str, Any] = { "role": "assistant", - "content": "".join(content_parts) or None, + "content": round_text or None, "tool_calls": tool_calls, } LLMClient.attach_reasoning_to_message( @@ -188,7 +212,7 @@ class ChatService: self._save_message( session_id, "assistant", - "".join(content_parts), + round_text, tool_calls=tool_calls, reasoning_json=reasoning_json, ) @@ -220,10 +244,12 @@ class ChatService: continue - final_content = "".join(content_parts) - if not final_content.strip() and reasoning: - final_content = reasoning - if not final_content.strip(): + final_content = "".join(content_parts).strip() + if not final_content and streamed_reply_parts: + final_content = "".join(streamed_reply_parts).strip() + if not final_content and reasoning: + final_content = reasoning.strip() + if not final_content: yield self._sse( "error", { @@ -238,20 +264,11 @@ class ChatService: self._save_message(session_id, "assistant", final_content) - memory_meta: dict[str, Any] = {} + yield self._sse("done", {}) if get_settings().memory_auto_extract: - extraction = await extract_after_turn( - self.db, - session_id, - user_text, - final_content, + asyncio.create_task( + _extract_memory_background(session_id, user_text, final_content) ) - memory_meta = { - "memory_extracted": extraction.get("count", 0), - "memory_saved": extraction.get("saved", []), - } - - yield self._sse("done", memory_meta) return yield self._sse("error", {"message": "Too many tool call rounds"}) diff --git a/backend/app/llm/client.py b/backend/app/llm/client.py index 7d4aab1..1b8dac2 100644 --- a/backend/app/llm/client.py +++ b/backend/app/llm/client.py @@ -147,8 +147,9 @@ class LLMClient: *, temperature: float = 0.7, model: str | None = None, + for_extraction: bool = False, ) -> dict[str, Any]: - use_tools = bool(tools) and self.tools_enabled + use_tools = bool(tools) and self.tools_enabled and not for_extraction kwargs: dict[str, Any] = { "model": model or self.model, "messages": messages, @@ -156,9 +157,12 @@ class LLMClient: } if use_tools: kwargs["tools"] = tools - extra_body = self._reasoning_extra_body() - if extra_body: - kwargs["extra_body"] = extra_body + if for_extraction: + kwargs["extra_body"] = {"reasoning": {"effort": "none"}} + else: + extra_body = self._reasoning_extra_body() + if extra_body: + kwargs["extra_body"] = extra_body response = await self.client.chat.completions.create(**kwargs) message = response.choices[0].message diff --git a/backend/app/memory/extract.py b/backend/app/memory/extract.py index 053aa80..7e591bd 100644 --- a/backend/app/memory/extract.py +++ b/backend/app/memory/extract.py @@ -82,6 +82,7 @@ async def _call_extractor( ], temperature=0.2, model=extract_model, + for_extraction=True, ) raw = strip_markdown_json(result.get("content") or "") if not raw: diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 54c7548..66ebf2c 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -11,6 +11,7 @@ export interface ChatMessage { id: number; role: string; content: string; + tool_calls_json?: string | null; created_at: string; } diff --git a/frontend/src/pages/Chat.tsx b/frontend/src/pages/Chat.tsx index c4c9ead..aab80fd 100644 --- a/frontend/src/pages/Chat.tsx +++ b/frontend/src/pages/Chat.tsx @@ -7,6 +7,7 @@ import "./Chat.css"; function shouldShowMessage(msg: ChatMessage): boolean { if (msg.role === "tool") return false; + if (msg.role === "assistant" && msg.tool_calls_json) return false; if (msg.role === "assistant" && !msg.content.trim()) return false; return true; }