fixed reasoning

This commit is contained in:
2026-06-10 14:37:27 +03:00
parent 905d756a25
commit 89158930ee
6 changed files with 47 additions and 22 deletions
+1
View File
@@ -20,6 +20,7 @@ class MessageOut(BaseModel):
id: int id: int
role: str role: str
content: str content: str
tool_calls_json: str | None = None
created_at: datetime created_at: datetime
model_config = {"from_attributes": True} model_config = {"from_attributes": True}
+35 -18
View File
@@ -1,4 +1,6 @@
import asyncio
import json import json
import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any from typing import Any
@@ -6,6 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import get_settings from app.config import get_settings
from app.db.base import SessionLocal
from app.character.service import CharacterService from app.character.service import CharacterService
from app.chat.notices import ( from app.chat.notices import (
POMODORO_TOOL_NAMES, POMODORO_TOOL_NAMES,
@@ -31,6 +34,22 @@ from app.tools.registry import TOOL_DEFINITIONS, execute_tool
MAX_TOOL_ROUNDS = 5 MAX_TOOL_ROUNDS = 5
MAX_HISTORY_MESSAGES = 40 MAX_HISTORY_MESSAGES = 40
logger = logging.getLogger(__name__)
async def _extract_memory_background(
session_id: int,
user_text: str,
assistant_text: str,
) -> None:
db = SessionLocal()
try:
await extract_after_turn(db, session_id, user_text, assistant_text)
except Exception as exc:
logger.warning("Background memory extraction failed: %s", exc)
finally:
db.close()
class ChatService: class ChatService:
def __init__(self, db: Session): def __init__(self, db: Session):
@@ -148,6 +167,7 @@ class ChatService:
self._save_message(session_id, "user", user_text) self._save_message(session_id, "user", user_text)
messages = self._build_messages(session) messages = self._build_messages(session)
streamed_reply_parts: list[str] = []
for _ in range(MAX_TOOL_ROUNDS): for _ in range(MAX_TOOL_ROUNDS):
content_parts: list[str] = [] content_parts: list[str] = []
@@ -170,9 +190,13 @@ class ChatService:
tool_calls = event["tool_calls"] tool_calls = event["tool_calls"]
if tool_calls: if tool_calls:
round_text = "".join(content_parts)
if round_text.strip():
streamed_reply_parts.append(round_text)
assistant_msg: dict[str, Any] = { assistant_msg: dict[str, Any] = {
"role": "assistant", "role": "assistant",
"content": "".join(content_parts) or None, "content": round_text or None,
"tool_calls": tool_calls, "tool_calls": tool_calls,
} }
LLMClient.attach_reasoning_to_message( LLMClient.attach_reasoning_to_message(
@@ -188,7 +212,7 @@ class ChatService:
self._save_message( self._save_message(
session_id, session_id,
"assistant", "assistant",
"".join(content_parts), round_text,
tool_calls=tool_calls, tool_calls=tool_calls,
reasoning_json=reasoning_json, reasoning_json=reasoning_json,
) )
@@ -220,10 +244,12 @@ class ChatService:
continue continue
final_content = "".join(content_parts) final_content = "".join(content_parts).strip()
if not final_content.strip() and reasoning: if not final_content and streamed_reply_parts:
final_content = reasoning final_content = "".join(streamed_reply_parts).strip()
if not final_content.strip(): if not final_content and reasoning:
final_content = reasoning.strip()
if not final_content:
yield self._sse( yield self._sse(
"error", "error",
{ {
@@ -238,20 +264,11 @@ class ChatService:
self._save_message(session_id, "assistant", final_content) self._save_message(session_id, "assistant", final_content)
memory_meta: dict[str, Any] = {} yield self._sse("done", {})
if get_settings().memory_auto_extract: if get_settings().memory_auto_extract:
extraction = await extract_after_turn( asyncio.create_task(
self.db, _extract_memory_background(session_id, user_text, final_content)
session_id,
user_text,
final_content,
) )
memory_meta = {
"memory_extracted": extraction.get("count", 0),
"memory_saved": extraction.get("saved", []),
}
yield self._sse("done", memory_meta)
return return
yield self._sse("error", {"message": "Too many tool call rounds"}) yield self._sse("error", {"message": "Too many tool call rounds"})
+5 -1
View File
@@ -147,8 +147,9 @@ class LLMClient:
*, *,
temperature: float = 0.7, temperature: float = 0.7,
model: str | None = None, model: str | None = None,
for_extraction: bool = False,
) -> dict[str, Any]: ) -> dict[str, Any]:
use_tools = bool(tools) and self.tools_enabled use_tools = bool(tools) and self.tools_enabled and not for_extraction
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": model or self.model, "model": model or self.model,
"messages": messages, "messages": messages,
@@ -156,6 +157,9 @@ class LLMClient:
} }
if use_tools: if use_tools:
kwargs["tools"] = tools kwargs["tools"] = tools
if for_extraction:
kwargs["extra_body"] = {"reasoning": {"effort": "none"}}
else:
extra_body = self._reasoning_extra_body() extra_body = self._reasoning_extra_body()
if extra_body: if extra_body:
kwargs["extra_body"] = extra_body kwargs["extra_body"] = extra_body
+1
View File
@@ -82,6 +82,7 @@ async def _call_extractor(
], ],
temperature=0.2, temperature=0.2,
model=extract_model, model=extract_model,
for_extraction=True,
) )
raw = strip_markdown_json(result.get("content") or "") raw = strip_markdown_json(result.get("content") or "")
if not raw: if not raw:
+1
View File
@@ -11,6 +11,7 @@ export interface ChatMessage {
id: number; id: number;
role: string; role: string;
content: string; content: string;
tool_calls_json?: string | null;
created_at: string; created_at: string;
} }
+1
View File
@@ -7,6 +7,7 @@ import "./Chat.css";
function shouldShowMessage(msg: ChatMessage): boolean { function shouldShowMessage(msg: ChatMessage): boolean {
if (msg.role === "tool") return false; if (msg.role === "tool") return false;
if (msg.role === "assistant" && msg.tool_calls_json) return false;
if (msg.role === "assistant" && !msg.content.trim()) return false; if (msg.role === "assistant" && !msg.content.trim()) return false;
return true; return true;
} }