fixed reasoning
This commit is contained in:
@@ -20,6 +20,7 @@ class MessageOut(BaseModel):
|
|||||||
id: int
|
id: int
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
|
tool_calls_json: str | None = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
|
|||||||
+35
-18
@@ -1,4 +1,6 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -6,6 +8,7 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import get_settings
|
from app.config import get_settings
|
||||||
|
from app.db.base import SessionLocal
|
||||||
from app.character.service import CharacterService
|
from app.character.service import CharacterService
|
||||||
from app.chat.notices import (
|
from app.chat.notices import (
|
||||||
POMODORO_TOOL_NAMES,
|
POMODORO_TOOL_NAMES,
|
||||||
@@ -31,6 +34,22 @@ from app.tools.registry import TOOL_DEFINITIONS, execute_tool
|
|||||||
MAX_TOOL_ROUNDS = 5
|
MAX_TOOL_ROUNDS = 5
|
||||||
MAX_HISTORY_MESSAGES = 40
|
MAX_HISTORY_MESSAGES = 40
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_memory_background(
    session_id: int,
    user_text: str,
    assistant_text: str,
) -> None:
    """Run post-turn memory extraction on a database session of its own.

    The chat service schedules this coroutine with ``asyncio.create_task``
    after the "done" event has been emitted, so nothing awaits its result.
    For that reason every failure is logged here instead of being raised.

    Args:
        session_id: Identifier of the chat session the turn belongs to.
        user_text: The user's message for the turn.
        assistant_text: The assistant's final reply for the turn.
    """
    db = None
    try:
        # Create the session inside the try block: a failure to open it must
        # be logged like any other extraction error, because an exception
        # escaping an un-awaited task would otherwise go unreported by us.
        db = SessionLocal()
        await extract_after_turn(db, session_id, user_text, assistant_text)
    except Exception as exc:
        # Best-effort work: never propagate, but keep the traceback so the
        # failure can actually be diagnosed from the logs.
        logger.warning(
            "Background memory extraction failed: %s", exc, exc_info=True
        )
    finally:
        # Only close what was successfully opened.
        if db is not None:
            db.close()
|
||||||
|
|
||||||
|
|
||||||
class ChatService:
|
class ChatService:
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
@@ -148,6 +167,7 @@ class ChatService:
|
|||||||
|
|
||||||
self._save_message(session_id, "user", user_text)
|
self._save_message(session_id, "user", user_text)
|
||||||
messages = self._build_messages(session)
|
messages = self._build_messages(session)
|
||||||
|
streamed_reply_parts: list[str] = []
|
||||||
|
|
||||||
for _ in range(MAX_TOOL_ROUNDS):
|
for _ in range(MAX_TOOL_ROUNDS):
|
||||||
content_parts: list[str] = []
|
content_parts: list[str] = []
|
||||||
@@ -170,9 +190,13 @@ class ChatService:
|
|||||||
tool_calls = event["tool_calls"]
|
tool_calls = event["tool_calls"]
|
||||||
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
|
round_text = "".join(content_parts)
|
||||||
|
if round_text.strip():
|
||||||
|
streamed_reply_parts.append(round_text)
|
||||||
|
|
||||||
assistant_msg: dict[str, Any] = {
|
assistant_msg: dict[str, Any] = {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "".join(content_parts) or None,
|
"content": round_text or None,
|
||||||
"tool_calls": tool_calls,
|
"tool_calls": tool_calls,
|
||||||
}
|
}
|
||||||
LLMClient.attach_reasoning_to_message(
|
LLMClient.attach_reasoning_to_message(
|
||||||
@@ -188,7 +212,7 @@ class ChatService:
|
|||||||
self._save_message(
|
self._save_message(
|
||||||
session_id,
|
session_id,
|
||||||
"assistant",
|
"assistant",
|
||||||
"".join(content_parts),
|
round_text,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
reasoning_json=reasoning_json,
|
reasoning_json=reasoning_json,
|
||||||
)
|
)
|
||||||
@@ -220,10 +244,12 @@ class ChatService:
|
|||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
final_content = "".join(content_parts)
|
final_content = "".join(content_parts).strip()
|
||||||
if not final_content.strip() and reasoning:
|
if not final_content and streamed_reply_parts:
|
||||||
final_content = reasoning
|
final_content = "".join(streamed_reply_parts).strip()
|
||||||
if not final_content.strip():
|
if not final_content and reasoning:
|
||||||
|
final_content = reasoning.strip()
|
||||||
|
if not final_content:
|
||||||
yield self._sse(
|
yield self._sse(
|
||||||
"error",
|
"error",
|
||||||
{
|
{
|
||||||
@@ -238,20 +264,11 @@ class ChatService:
|
|||||||
|
|
||||||
self._save_message(session_id, "assistant", final_content)
|
self._save_message(session_id, "assistant", final_content)
|
||||||
|
|
||||||
memory_meta: dict[str, Any] = {}
|
yield self._sse("done", {})
|
||||||
if get_settings().memory_auto_extract:
|
if get_settings().memory_auto_extract:
|
||||||
extraction = await extract_after_turn(
|
asyncio.create_task(
|
||||||
self.db,
|
_extract_memory_background(session_id, user_text, final_content)
|
||||||
session_id,
|
|
||||||
user_text,
|
|
||||||
final_content,
|
|
||||||
)
|
)
|
||||||
memory_meta = {
|
|
||||||
"memory_extracted": extraction.get("count", 0),
|
|
||||||
"memory_saved": extraction.get("saved", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
yield self._sse("done", memory_meta)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self._sse("error", {"message": "Too many tool call rounds"})
|
yield self._sse("error", {"message": "Too many tool call rounds"})
|
||||||
|
|||||||
@@ -147,8 +147,9 @@ class LLMClient:
|
|||||||
*,
|
*,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
|
for_extraction: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
use_tools = bool(tools) and self.tools_enabled
|
use_tools = bool(tools) and self.tools_enabled and not for_extraction
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model or self.model,
|
"model": model or self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
@@ -156,9 +157,12 @@ class LLMClient:
|
|||||||
}
|
}
|
||||||
if use_tools:
|
if use_tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
extra_body = self._reasoning_extra_body()
|
if for_extraction:
|
||||||
if extra_body:
|
kwargs["extra_body"] = {"reasoning": {"effort": "none"}}
|
||||||
kwargs["extra_body"] = extra_body
|
else:
|
||||||
|
extra_body = self._reasoning_extra_body()
|
||||||
|
if extra_body:
|
||||||
|
kwargs["extra_body"] = extra_body
|
||||||
|
|
||||||
response = await self.client.chat.completions.create(**kwargs)
|
response = await self.client.chat.completions.create(**kwargs)
|
||||||
message = response.choices[0].message
|
message = response.choices[0].message
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ async def _call_extractor(
|
|||||||
],
|
],
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
model=extract_model,
|
model=extract_model,
|
||||||
|
for_extraction=True,
|
||||||
)
|
)
|
||||||
raw = strip_markdown_json(result.get("content") or "")
|
raw = strip_markdown_json(result.get("content") or "")
|
||||||
if not raw:
|
if not raw:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ export interface ChatMessage {
|
|||||||
id: number;
|
id: number;
|
||||||
role: string;
|
role: string;
|
||||||
content: string;
|
content: string;
|
||||||
|
tool_calls_json?: string | null;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import "./Chat.css";
|
|||||||
|
|
||||||
/**
 * Decide whether a chat message is rendered in the conversation view.
 *
 * Hidden messages:
 *  - every message with role "tool";
 *  - assistant messages that carry `tool_calls_json`;
 *  - assistant messages whose content is empty or whitespace only.
 * Everything else is shown.
 */
function shouldShowMessage(msg: ChatMessage): boolean {
  if (msg.role === "tool") {
    return false;
  }

  // Non-assistant messages have no further conditions.
  if (msg.role !== "assistant") {
    return true;
  }

  // Check the tool-call marker first so content is not inspected for
  // assistant turns that are hidden anyway.
  if (msg.tool_calls_json) {
    return false;
  }

  const visibleText = msg.content.trim();
  return visibleText ? true : false;
}
|
||||||
|
|||||||
Reference in New Issue
Block a user