fixed reasoning
This commit is contained in:
+35
-18
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
@@ -6,6 +8,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import get_settings
|
||||
from app.db.base import SessionLocal
|
||||
from app.character.service import CharacterService
|
||||
from app.chat.notices import (
|
||||
POMODORO_TOOL_NAMES,
|
||||
@@ -31,6 +34,22 @@ from app.tools.registry import TOOL_DEFINITIONS, execute_tool
|
||||
# Upper bound on iterations of the chat loop (`for _ in range(MAX_TOOL_ROUNDS)`);
# the same method emits a "Too many tool call rounds" error event, presumably
# once this limit is exhausted without a final reply.
MAX_TOOL_ROUNDS = 5
|
||||
# NOTE(review): not referenced in the code visible here; presumably caps how
# many stored messages are replayed as conversation history — confirm against
# the message-building code.
MAX_HISTORY_MESSAGES = 40
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _extract_memory_background(
    session_id: int,
    user_text: str,
    assistant_text: str,
) -> None:
    """Run post-turn memory extraction on a dedicated database session.

    Opens its own ``SessionLocal`` session, awaits ``extract_after_turn`` with
    it, and always closes the session afterwards. The chat stream schedules
    this coroutine with ``asyncio.create_task`` and never awaits it, so every
    failure is logged here instead of being raised.

    Args:
        session_id: Identifier of the chat session the turn belongs to.
        user_text: The user's message for the turn.
        assistant_text: The assistant's final reply for the turn.
    """
    db = None
    try:
        # Create the session inside the try block: nobody awaits this task, so
        # an error raised here would otherwise only surface as an
        # "exception was never retrieved" warning from the event loop.
        db = SessionLocal()
        await extract_after_turn(db, session_id, user_text, assistant_text)
    except Exception as exc:
        # Best-effort by design; keep the traceback so background failures
        # remain debuggable.
        logger.warning(
            "Background memory extraction failed: %s", exc, exc_info=True
        )
    finally:
        if db is not None:
            db.close()
|
||||
|
||||
|
||||
class ChatService:
|
||||
def __init__(self, db: Session):
|
||||
@@ -148,6 +167,7 @@ class ChatService:
|
||||
|
||||
self._save_message(session_id, "user", user_text)
|
||||
messages = self._build_messages(session)
|
||||
streamed_reply_parts: list[str] = []
|
||||
|
||||
for _ in range(MAX_TOOL_ROUNDS):
|
||||
content_parts: list[str] = []
|
||||
@@ -170,9 +190,13 @@ class ChatService:
|
||||
tool_calls = event["tool_calls"]
|
||||
|
||||
if tool_calls:
|
||||
round_text = "".join(content_parts)
|
||||
if round_text.strip():
|
||||
streamed_reply_parts.append(round_text)
|
||||
|
||||
assistant_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": "".join(content_parts) or None,
|
||||
"content": round_text or None,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
LLMClient.attach_reasoning_to_message(
|
||||
@@ -188,7 +212,7 @@ class ChatService:
|
||||
self._save_message(
|
||||
session_id,
|
||||
"assistant",
|
||||
"".join(content_parts),
|
||||
round_text,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_json=reasoning_json,
|
||||
)
|
||||
@@ -220,10 +244,12 @@ class ChatService:
|
||||
|
||||
continue
|
||||
|
||||
final_content = "".join(content_parts)
|
||||
if not final_content.strip() and reasoning:
|
||||
final_content = reasoning
|
||||
if not final_content.strip():
|
||||
final_content = "".join(content_parts).strip()
|
||||
if not final_content and streamed_reply_parts:
|
||||
final_content = "".join(streamed_reply_parts).strip()
|
||||
if not final_content and reasoning:
|
||||
final_content = reasoning.strip()
|
||||
if not final_content:
|
||||
yield self._sse(
|
||||
"error",
|
||||
{
|
||||
@@ -238,20 +264,11 @@ class ChatService:
|
||||
|
||||
self._save_message(session_id, "assistant", final_content)
|
||||
|
||||
memory_meta: dict[str, Any] = {}
|
||||
yield self._sse("done", {})
|
||||
if get_settings().memory_auto_extract:
|
||||
extraction = await extract_after_turn(
|
||||
self.db,
|
||||
session_id,
|
||||
user_text,
|
||||
final_content,
|
||||
asyncio.create_task(
|
||||
_extract_memory_background(session_id, user_text, final_content)
|
||||
)
|
||||
memory_meta = {
|
||||
"memory_extracted": extraction.get("count", 0),
|
||||
"memory_saved": extraction.get("saved", []),
|
||||
}
|
||||
|
||||
yield self._sse("done", memory_meta)
|
||||
return
|
||||
|
||||
yield self._sse("error", {"message": "Too many tool call rounds"})
|
||||
|
||||
Reference in New Issue
Block a user