fixed reasoning

This commit is contained in:
2026-06-10 14:37:27 +03:00
parent 905d756a25
commit 89158930ee
6 changed files with 47 additions and 22 deletions
+1
View File
@@ -20,6 +20,7 @@ class MessageOut(BaseModel):
id: int
role: str
content: str
tool_calls_json: str | None = None
created_at: datetime
model_config = {"from_attributes": True}
+35 -18
View File
@@ -1,4 +1,6 @@
import asyncio
import json
import logging
from collections.abc import AsyncIterator
from typing import Any
@@ -6,6 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import get_settings
from app.db.base import SessionLocal
from app.character.service import CharacterService
from app.chat.notices import (
POMODORO_TOOL_NAMES,
@@ -31,6 +34,22 @@ from app.tools.registry import TOOL_DEFINITIONS, execute_tool
MAX_TOOL_ROUNDS = 5
MAX_HISTORY_MESSAGES = 40
logger = logging.getLogger(__name__)
async def _extract_memory_background(
    session_id: int,
    user_text: str,
    assistant_text: str,
) -> None:
    """Run memory extraction for one chat turn on a dedicated DB session.

    Creates its own ``SessionLocal`` session, awaits ``extract_after_turn``
    with it, and always closes the session afterwards. The call is
    best-effort: any ``Exception`` raised by the extraction is logged as a
    warning and swallowed, so nothing propagates to whoever awaits or
    schedules this coroutine.

    Args:
        session_id: Forwarded unchanged to ``extract_after_turn``.
        user_text: Forwarded unchanged to ``extract_after_turn``.
        assistant_text: Forwarded unchanged to ``extract_after_turn``.
    """
    db = SessionLocal()
    try:
        await extract_after_turn(db, session_id, user_text, assistant_text)
    except Exception as exc:
        # Nobody awaits the result of this coroutine's failure, so the log
        # record is the only trace of it: attach the traceback instead of
        # just the exception's string form.
        logger.warning(
            "Background memory extraction failed: %s",
            exc,
            exc_info=True,
        )
    finally:
        # Runs on success, failure and cancellation alike, so the session
        # is never left open.
        db.close()
class ChatService:
def __init__(self, db: Session):
@@ -148,6 +167,7 @@ class ChatService:
self._save_message(session_id, "user", user_text)
messages = self._build_messages(session)
streamed_reply_parts: list[str] = []
for _ in range(MAX_TOOL_ROUNDS):
content_parts: list[str] = []
@@ -170,9 +190,13 @@ class ChatService:
tool_calls = event["tool_calls"]
if tool_calls:
round_text = "".join(content_parts)
if round_text.strip():
streamed_reply_parts.append(round_text)
assistant_msg: dict[str, Any] = {
"role": "assistant",
"content": "".join(content_parts) or None,
"content": round_text or None,
"tool_calls": tool_calls,
}
LLMClient.attach_reasoning_to_message(
@@ -188,7 +212,7 @@ class ChatService:
self._save_message(
session_id,
"assistant",
"".join(content_parts),
round_text,
tool_calls=tool_calls,
reasoning_json=reasoning_json,
)
@@ -220,10 +244,12 @@ class ChatService:
continue
final_content = "".join(content_parts)
if not final_content.strip() and reasoning:
final_content = reasoning
if not final_content.strip():
final_content = "".join(content_parts).strip()
if not final_content and streamed_reply_parts:
final_content = "".join(streamed_reply_parts).strip()
if not final_content and reasoning:
final_content = reasoning.strip()
if not final_content:
yield self._sse(
"error",
{
@@ -238,20 +264,11 @@ class ChatService:
self._save_message(session_id, "assistant", final_content)
memory_meta: dict[str, Any] = {}
yield self._sse("done", {})
if get_settings().memory_auto_extract:
extraction = await extract_after_turn(
self.db,
session_id,
user_text,
final_content,
asyncio.create_task(
_extract_memory_background(session_id, user_text, final_content)
)
memory_meta = {
"memory_extracted": extraction.get("count", 0),
"memory_saved": extraction.get("saved", []),
}
yield self._sse("done", memory_meta)
return
yield self._sse("error", {"message": "Too many tool call rounds"})
+8 -4
View File
@@ -147,8 +147,9 @@ class LLMClient:
*,
temperature: float = 0.7,
model: str | None = None,
for_extraction: bool = False,
) -> dict[str, Any]:
use_tools = bool(tools) and self.tools_enabled
use_tools = bool(tools) and self.tools_enabled and not for_extraction
kwargs: dict[str, Any] = {
"model": model or self.model,
"messages": messages,
@@ -156,9 +157,12 @@ class LLMClient:
}
if use_tools:
kwargs["tools"] = tools
extra_body = self._reasoning_extra_body()
if extra_body:
kwargs["extra_body"] = extra_body
if for_extraction:
kwargs["extra_body"] = {"reasoning": {"effort": "none"}}
else:
extra_body = self._reasoning_extra_body()
if extra_body:
kwargs["extra_body"] = extra_body
response = await self.client.chat.completions.create(**kwargs)
message = response.choices[0].message
+1
View File
@@ -82,6 +82,7 @@ async def _call_extractor(
],
temperature=0.2,
model=extract_model,
for_extraction=True,
)
raw = strip_markdown_json(result.get("content") or "")
if not raw: