fixed reasoning
This commit is contained in:
+94
-10
@@ -15,21 +15,48 @@ class LLMClient:
|
||||
settings = get_settings()
|
||||
self.model = settings.openrouter_model
|
||||
self.tools_enabled = settings.openrouter_tools_enabled
|
||||
self.reasoning_effort = settings.openrouter_reasoning_effort.strip().lower()
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=settings.openrouter_api_key,
|
||||
base_url=settings.openrouter_base_url,
|
||||
)
|
||||
|
||||
def _delta_text(self, delta: Any) -> str:
|
||||
def _reasoning_extra_body(self) -> dict[str, Any] | None:
|
||||
if not self.reasoning_effort:
|
||||
return None
|
||||
return {"reasoning": {"effort": self.reasoning_effort}}
|
||||
|
||||
@staticmethod
|
||||
def _delta_reasoning(delta: Any) -> tuple[str, list[Any]]:
|
||||
parts: list[str] = []
|
||||
if getattr(delta, "content", None):
|
||||
parts.append(delta.content)
|
||||
# Reasoning-модели (OpenRouter / o-series) иногда пишут сюда, а не в content.
|
||||
for attr in ("reasoning", "reasoning_content"):
|
||||
value = getattr(delta, attr, None)
|
||||
if value:
|
||||
parts.append(str(value))
|
||||
return "".join(parts)
|
||||
|
||||
details: list[Any] = []
|
||||
raw_details = getattr(delta, "reasoning_details", None)
|
||||
if raw_details:
|
||||
if isinstance(raw_details, list):
|
||||
details.extend(raw_details)
|
||||
else:
|
||||
details.append(raw_details)
|
||||
|
||||
return "".join(parts), details
|
||||
|
||||
@staticmethod
|
||||
def attach_reasoning_to_message(
|
||||
message: dict[str, Any],
|
||||
*,
|
||||
reasoning: str = "",
|
||||
reasoning_details: list[Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if reasoning:
|
||||
message["reasoning"] = reasoning
|
||||
message["reasoning_content"] = reasoning
|
||||
if reasoning_details:
|
||||
message["reasoning_details"] = reasoning_details
|
||||
return message
|
||||
|
||||
async def stream_chat(
|
||||
self,
|
||||
@@ -47,6 +74,9 @@ class LLMClient:
|
||||
}
|
||||
if use_tools:
|
||||
kwargs["tools"] = tools
|
||||
extra_body = self._reasoning_extra_body()
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
|
||||
try:
|
||||
stream = await self.client.chat.completions.create(**kwargs)
|
||||
@@ -57,6 +87,8 @@ class LLMClient:
|
||||
return
|
||||
|
||||
tool_calls: dict[int, dict[str, Any]] = {}
|
||||
reasoning_parts: list[str] = []
|
||||
reasoning_details: list[Any] = []
|
||||
|
||||
try:
|
||||
async for chunk in stream:
|
||||
@@ -66,9 +98,14 @@ class LLMClient:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
|
||||
text = self._delta_text(delta)
|
||||
if text:
|
||||
yield {"type": "content", "content": text}
|
||||
if delta.content:
|
||||
yield {"type": "content", "content": delta.content}
|
||||
|
||||
reasoning_text, details = self._delta_reasoning(delta)
|
||||
if reasoning_text:
|
||||
reasoning_parts.append(reasoning_text)
|
||||
if details:
|
||||
reasoning_details.extend(details)
|
||||
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
@@ -88,6 +125,13 @@ class LLMClient:
|
||||
tool_calls[idx]["function"]["arguments"] += tool_call.function.arguments
|
||||
|
||||
if choice.finish_reason:
|
||||
reasoning = "".join(reasoning_parts)
|
||||
if reasoning or reasoning_details:
|
||||
yield {
|
||||
"type": "reasoning",
|
||||
"reasoning": reasoning,
|
||||
"reasoning_details": reasoning_details or None,
|
||||
}
|
||||
if tool_calls:
|
||||
yield {"type": "tool_calls", "tool_calls": list(tool_calls.values())}
|
||||
yield {"type": "done", "finish_reason": choice.finish_reason}
|
||||
@@ -112,19 +156,29 @@ class LLMClient:
|
||||
}
|
||||
if use_tools:
|
||||
kwargs["tools"] = tools
|
||||
extra_body = self._reasoning_extra_body()
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
|
||||
response = await self.client.chat.completions.create(**kwargs)
|
||||
message = response.choices[0].message
|
||||
|
||||
content = message.content or ""
|
||||
reasoning = ""
|
||||
for attr in ("reasoning", "reasoning_content"):
|
||||
value = getattr(message, attr, None)
|
||||
if value and not content:
|
||||
content = str(value)
|
||||
if value:
|
||||
reasoning = str(value)
|
||||
break
|
||||
|
||||
if not content and reasoning:
|
||||
content = reasoning
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"content": content,
|
||||
"tool_calls": [],
|
||||
"reasoning": reasoning,
|
||||
"reasoning_details": getattr(message, "reasoning_details", None),
|
||||
}
|
||||
|
||||
if message.tool_calls:
|
||||
@@ -150,3 +204,33 @@ class LLMClient:
|
||||
return json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def serialize_reasoning(
|
||||
*,
|
||||
reasoning: str = "",
|
||||
reasoning_details: list[Any] | None = None,
|
||||
) -> str | None:
|
||||
payload: dict[str, Any] = {}
|
||||
if reasoning:
|
||||
payload["reasoning"] = reasoning
|
||||
payload["reasoning_content"] = reasoning
|
||||
if reasoning_details:
|
||||
payload["reasoning_details"] = reasoning_details
|
||||
if not payload:
|
||||
return None
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_reasoning(raw: str | None) -> dict[str, Any]:
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return {"reasoning": raw}
|
||||
if isinstance(data, str):
|
||||
return {"reasoning": data, "reasoning_content": data}
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
return {}
|
||||
|
||||
Reference in New Issue
Block a user