fixed reasoning
This commit is contained in:
+71
-33
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
@@ -6,65 +7,94 @@ from openai import AsyncOpenAI
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self) -> None:
|
||||
settings = get_settings()
|
||||
self.model = settings.openrouter_model
|
||||
self.tools_enabled = settings.openrouter_tools_enabled
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=settings.openrouter_api_key,
|
||||
base_url=settings.openrouter_base_url,
|
||||
)
|
||||
|
||||
def _delta_text(self, delta: Any) -> str:
|
||||
parts: list[str] = []
|
||||
if getattr(delta, "content", None):
|
||||
parts.append(delta.content)
|
||||
# Reasoning-модели (OpenRouter / o-series) иногда пишут сюда, а не в content.
|
||||
for attr in ("reasoning", "reasoning_content"):
|
||||
value = getattr(delta, attr, None)
|
||||
if value:
|
||||
parts.append(str(value))
|
||||
return "".join(parts)
|
||||
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
use_tools = bool(tools) and self.tools_enabled
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"model": model or self.model,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
if tools:
|
||||
if use_tools:
|
||||
kwargs["tools"] = tools
|
||||
|
||||
stream = await self.client.chat.completions.create(**kwargs)
|
||||
try:
|
||||
stream = await self.client.chat.completions.create(**kwargs)
|
||||
except Exception as exc:
|
||||
logger.exception("LLM stream failed: %s", exc)
|
||||
yield {"type": "error", "content": str(exc)}
|
||||
yield {"type": "done", "finish_reason": "error"}
|
||||
return
|
||||
|
||||
tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
try:
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
|
||||
if delta.content:
|
||||
yield {"type": "content", "content": delta.content}
|
||||
text = self._delta_text(delta)
|
||||
if text:
|
||||
yield {"type": "content", "content": text}
|
||||
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
idx = tool_call.index
|
||||
if idx not in tool_calls:
|
||||
tool_calls[idx] = {
|
||||
"id": tool_call.id or "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
if tool_call.id:
|
||||
tool_calls[idx]["id"] = tool_call.id
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
tool_calls[idx]["function"]["name"] = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
tool_calls[idx]["function"]["arguments"] += tool_call.function.arguments
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
idx = tool_call.index
|
||||
if idx not in tool_calls:
|
||||
tool_calls[idx] = {
|
||||
"id": tool_call.id or "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
if tool_call.id:
|
||||
tool_calls[idx]["id"] = tool_call.id
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
tool_calls[idx]["function"]["name"] = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
tool_calls[idx]["function"]["arguments"] += tool_call.function.arguments
|
||||
|
||||
if choice.finish_reason:
|
||||
if tool_calls:
|
||||
yield {"type": "tool_calls", "tool_calls": list(tool_calls.values())}
|
||||
yield {"type": "done", "finish_reason": choice.finish_reason}
|
||||
if choice.finish_reason:
|
||||
if tool_calls:
|
||||
yield {"type": "tool_calls", "tool_calls": list(tool_calls.values())}
|
||||
yield {"type": "done", "finish_reason": choice.finish_reason}
|
||||
except Exception as exc:
|
||||
logger.exception("LLM stream read failed: %s", exc)
|
||||
yield {"type": "error", "content": str(exc)}
|
||||
yield {"type": "done", "finish_reason": "error"}
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
@@ -72,20 +102,28 @@ class LLMClient:
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
model: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
use_tools = bool(tools) and self.tools_enabled
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"model": model or self.model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
if use_tools:
|
||||
kwargs["tools"] = tools
|
||||
|
||||
response = await self.client.chat.completions.create(**kwargs)
|
||||
message = response.choices[0].message
|
||||
|
||||
content = message.content or ""
|
||||
for attr in ("reasoning", "reasoning_content"):
|
||||
value = getattr(message, attr, None)
|
||||
if value and not content:
|
||||
content = str(value)
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"content": message.content or "",
|
||||
"content": content,
|
||||
"tool_calls": [],
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user