270 lines
9.4 KiB
Plaintext
270 lines
9.4 KiB
Plaintext
import json
|
|
import logging
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
from openai import AsyncOpenAI
|
|
|
|
from app.config import get_settings
|
|
|
|
# Module-level logger named after this module; used by LLMClient to report
# stream failures and end-of-stream statistics.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMClient:
|
|
def __init__(self) -> None:
|
|
settings = get_settings()
|
|
self.model = settings.openrouter_model
|
|
self.tools_enabled = settings.openrouter_tools_enabled
|
|
self.reasoning_effort = settings.openrouter_reasoning_effort.strip().lower()
|
|
self.client = AsyncOpenAI(
|
|
api_key=settings.openrouter_api_key,
|
|
base_url=settings.openrouter_base_url,
|
|
)
|
|
|
|
def _reasoning_extra_body(self) -> dict[str, Any] | None:
|
|
if not self.reasoning_effort:
|
|
return None
|
|
if self.reasoning_effort == "none":
|
|
return {"reasoning": {"effort": "none", "exclude": True}}
|
|
return {"reasoning": {"effort": self.reasoning_effort}}
|
|
|
|
@staticmethod
|
|
def _delta_reasoning(delta: Any) -> tuple[str, list[Any]]:
|
|
parts: list[str] = []
|
|
for attr in ("reasoning", "reasoning_content"):
|
|
value = getattr(delta, attr, None)
|
|
if value:
|
|
parts.append(str(value))
|
|
|
|
details: list[Any] = []
|
|
raw_details = getattr(delta, "reasoning_details", None)
|
|
if raw_details:
|
|
if isinstance(raw_details, list):
|
|
details.extend(raw_details)
|
|
else:
|
|
details.append(raw_details)
|
|
|
|
return "".join(parts), details
|
|
|
|
@staticmethod
|
|
def _normalize_reasoning_details(details: Any) -> list[Any] | None:
|
|
if not details:
|
|
return None
|
|
items = details if isinstance(details, list) else [details]
|
|
normalized: list[Any] = []
|
|
for item in items:
|
|
if hasattr(item, "model_dump"):
|
|
normalized.append(item.model_dump())
|
|
elif isinstance(item, dict):
|
|
normalized.append(item)
|
|
else:
|
|
normalized.append(item)
|
|
return normalized or None
|
|
|
|
@staticmethod
|
|
def attach_reasoning_to_message(
|
|
message: dict[str, Any],
|
|
*,
|
|
reasoning: str = "",
|
|
reasoning_details: list[Any] | None = None,
|
|
) -> dict[str, Any]:
|
|
if reasoning:
|
|
message["reasoning"] = reasoning
|
|
message["reasoning_content"] = reasoning
|
|
normalized = LLMClient._normalize_reasoning_details(reasoning_details)
|
|
if normalized:
|
|
message["reasoning_details"] = normalized
|
|
return message
|
|
|
|
async def stream_chat(
|
|
self,
|
|
messages: list[dict[str, Any]],
|
|
tools: list[dict[str, Any]] | None = None,
|
|
*,
|
|
model: str | None = None,
|
|
) -> AsyncIterator[dict[str, Any]]:
|
|
use_tools = bool(tools) and self.tools_enabled
|
|
kwargs: dict[str, Any] = {
|
|
"model": model or self.model,
|
|
"messages": messages,
|
|
"stream": True,
|
|
"temperature": 0.7,
|
|
}
|
|
if use_tools:
|
|
kwargs["tools"] = tools
|
|
extra_body = self._reasoning_extra_body()
|
|
if extra_body:
|
|
kwargs["extra_body"] = extra_body
|
|
|
|
try:
|
|
stream = await self.client.chat.completions.create(**kwargs)
|
|
except Exception as exc:
|
|
logger.exception("LLM stream failed: %s", exc)
|
|
yield {"type": "error", "content": str(exc)}
|
|
yield {"type": "done", "finish_reason": "error"}
|
|
return
|
|
|
|
tool_calls: dict[int, dict[str, Any]] = {}
|
|
reasoning_parts: list[str] = []
|
|
reasoning_details: list[Any] = []
|
|
|
|
try:
|
|
async for chunk in stream:
|
|
if not chunk.choices:
|
|
continue
|
|
|
|
choice = chunk.choices[0]
|
|
delta = choice.delta
|
|
|
|
if delta.content:
|
|
yield {"type": "content", "content": delta.content}
|
|
|
|
reasoning_text, details = self._delta_reasoning(delta)
|
|
if reasoning_text:
|
|
reasoning_parts.append(reasoning_text)
|
|
if details:
|
|
reasoning_details.extend(details)
|
|
|
|
if delta.tool_calls:
|
|
for tool_call in delta.tool_calls:
|
|
idx = tool_call.index
|
|
if idx not in tool_calls:
|
|
tool_calls[idx] = {
|
|
"id": tool_call.id or "",
|
|
"type": "function",
|
|
"function": {"name": "", "arguments": ""},
|
|
}
|
|
if tool_call.id:
|
|
tool_calls[idx]["id"] = tool_call.id
|
|
if tool_call.function:
|
|
if tool_call.function.name:
|
|
tool_calls[idx]["function"]["name"] = tool_call.function.name
|
|
if tool_call.function.arguments:
|
|
tool_calls[idx]["function"]["arguments"] += tool_call.function.arguments
|
|
|
|
if choice.finish_reason:
|
|
reasoning = "".join(reasoning_parts)
|
|
normalized_details = self._normalize_reasoning_details(reasoning_details)
|
|
if reasoning or normalized_details:
|
|
yield {
|
|
"type": "reasoning",
|
|
"reasoning": reasoning,
|
|
"reasoning_details": normalized_details,
|
|
}
|
|
if tool_calls:
|
|
yield {"type": "tool_calls", "tool_calls": list(tool_calls.values())}
|
|
logger.info(
|
|
"LLM stream done: model=%s finish_reason=%s tool_calls=%d "
|
|
"content_in_stream=%d reasoning_len=%d",
|
|
model or self.model,
|
|
choice.finish_reason,
|
|
len(tool_calls),
|
|
len(reasoning_parts),
|
|
len(reasoning),
|
|
)
|
|
yield {"type": "done", "finish_reason": choice.finish_reason}
|
|
except Exception as exc:
|
|
logger.exception("LLM stream read failed: %s", exc)
|
|
yield {"type": "error", "content": str(exc)}
|
|
yield {"type": "done", "finish_reason": "error"}
|
|
|
|
async def complete(
|
|
self,
|
|
messages: list[dict[str, Any]],
|
|
tools: list[dict[str, Any]] | None = None,
|
|
*,
|
|
temperature: float = 0.7,
|
|
model: str | None = None,
|
|
for_extraction: bool = False,
|
|
visible_reply: bool = False,
|
|
) -> dict[str, Any]:
|
|
use_tools = bool(tools) and self.tools_enabled and not for_extraction
|
|
kwargs: dict[str, Any] = {
|
|
"model": model or self.model,
|
|
"messages": messages,
|
|
"temperature": temperature,
|
|
}
|
|
if use_tools:
|
|
kwargs["tools"] = tools
|
|
if for_extraction:
|
|
kwargs["extra_body"] = {"reasoning": {"effort": "none"}}
|
|
else:
|
|
extra_body = self._reasoning_extra_body()
|
|
if extra_body:
|
|
kwargs["extra_body"] = extra_body
|
|
|
|
response = await self.client.chat.completions.create(**kwargs)
|
|
message = response.choices[0].message
|
|
|
|
content = message.content or ""
|
|
reasoning = ""
|
|
for attr in ("reasoning", "reasoning_content"):
|
|
value = getattr(message, attr, None)
|
|
if value:
|
|
reasoning = str(value)
|
|
break
|
|
|
|
if not content and reasoning and not visible_reply:
|
|
content = reasoning
|
|
|
|
result: dict[str, Any] = {
|
|
"content": content,
|
|
"tool_calls": [],
|
|
"reasoning": reasoning,
|
|
"reasoning_details": getattr(message, "reasoning_details", None),
|
|
}
|
|
|
|
if message.tool_calls:
|
|
result["tool_calls"] = [
|
|
{
|
|
"id": tc.id,
|
|
"type": "function",
|
|
"function": {
|
|
"name": tc.function.name,
|
|
"arguments": tc.function.arguments,
|
|
},
|
|
}
|
|
for tc in message.tool_calls
|
|
]
|
|
|
|
return result
|
|
|
|
@staticmethod
|
|
def parse_tool_arguments(arguments: str) -> dict[str, Any]:
|
|
if not arguments:
|
|
return {}
|
|
try:
|
|
return json.loads(arguments)
|
|
except json.JSONDecodeError:
|
|
return {}
|
|
|
|
@staticmethod
|
|
def serialize_reasoning(
|
|
*,
|
|
reasoning: str = "",
|
|
reasoning_details: list[Any] | None = None,
|
|
) -> str | None:
|
|
payload: dict[str, Any] = {}
|
|
if reasoning:
|
|
payload["reasoning"] = reasoning
|
|
payload["reasoning_content"] = reasoning
|
|
if reasoning_details:
|
|
payload["reasoning_details"] = reasoning_details
|
|
if not payload:
|
|
return None
|
|
return json.dumps(payload, ensure_ascii=False)
|
|
|
|
@staticmethod
|
|
def deserialize_reasoning(raw: str | None) -> dict[str, Any]:
|
|
if not raw:
|
|
return {}
|
|
try:
|
|
data = json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
return {"reasoning": raw}
|
|
if isinstance(data, str):
|
|
return {"reasoning": data, "reasoning_content": data}
|
|
if isinstance(data, dict):
|
|
return data
|
|
return {}
|