Fixed SD Prompt
This commit is contained in:
+119
-6
@@ -13,6 +13,8 @@ OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
|
||||
|
||||
# Default chat model id. A blank or whitespace-only env value falls back to the
# built-in default instead of producing an empty model id in the request.
CHAT_MODEL = (os.getenv("CHAT_MODEL") or "").strip() or "mistralai/mistral-nemo"
# Model used by send_message(); normalised the same way as CHAT_MODEL.
SYSTEM_MODEL = (os.getenv("SYSTEM_MODEL") or "").strip() or "google/gemini-2.5-flash"
# Softer model when primary returns content_filter / empty / API errors (default: CHAT_MODEL).
LLM_FALLBACK_MODEL = (os.getenv("LLM_FALLBACK_MODEL") or "").strip() or CHAT_MODEL
|
||||
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {OPENROUTER_KEY}",
|
||||
@@ -21,26 +23,128 @@ HEADERS = {
|
||||
}
|
||||
|
||||
|
||||
class LLMError(Exception):
    """OpenRouter returned an error or an unexpected response shape."""


# finish_reason values treated as a provider-side content block.
_BLOCKED_FINISH_REASONS = frozenset({"content_filter", "safety", "moderation"})
# native_finish_reason values (compared upper-cased) treated as a content block.
_BLOCKED_NATIVE_REASONS = frozenset({"PROHIBITED_CONTENT", "SAFETY", "BLOCKED"})
# finish_reason values that get the generic "empty content" error when the
# message has no content, instead of the "finished without content" error.
_NORMAL_FINISH_REASONS = frozenset({"stop", "length", "tool_calls", "function_call"})


def _parse_completion_body(data: dict) -> str:
    """Extract the assistant text from a decoded chat-completion response.

    Args:
        data: The JSON-decoded response body.

    Returns:
        The first choice's message content as a string (non-empty after
        stripping whitespace).

    Raises:
        LLMError: If the body is not an object, carries an ``error`` field,
            has no usable ``choices`` list, was blocked by a content filter,
            was refused by the model, or contains no message content.
    """
    if not isinstance(data, dict):
        raise LLMError(f"Invalid API response: expected object, got {type(data).__name__}")

    err = data.get("error")
    if err:
        if isinstance(err, dict):
            msg = err.get("message") or str(err)
            code = err.get("code")
        else:
            msg = str(err)
            code = None
        suffix = f" (code={code})" if code is not None else ""
        raise LLMError(f"OpenRouter error{suffix}: {msg}")

    choices = data.get("choices")
    # A truthy non-list (e.g. a dict) would otherwise fail on choices[0] with
    # KeyError/TypeError instead of the LLMError callers expect.
    if not isinstance(choices, list) or not choices:
        preview = str(data)[:400]
        raise LLMError(f"OpenRouter response has no 'choices'. Body preview: {preview}")

    first = choices[0] if isinstance(choices[0], dict) else {}
    message = first.get("message") or {}
    if not isinstance(message, dict):
        raise LLMError("OpenRouter choice has no message object")

    # Normalise to str so an unexpected (unhashable) value cannot break the
    # set-membership checks below.
    finish = str(first.get("finish_reason") or "")
    native_finish = first.get("native_finish_reason") or ""
    if (
        finish in _BLOCKED_FINISH_REASONS
        or str(native_finish).upper() in _BLOCKED_NATIVE_REASONS
    ):
        raise LLMError(
            f"Content blocked by provider (finish_reason={finish}, native={native_finish})"
        )

    content = message.get("content")
    if content is not None and str(content).strip():
        return str(content)

    refusal = message.get("refusal")
    if refusal:
        raise LLMError(f"Model refused the request: {refusal}")

    if finish and finish not in _NORMAL_FINISH_REASONS:
        raise LLMError(
            f"OpenRouter finished without content (finish_reason={finish}, native={native_finish})"
        )

    raise LLMError("OpenRouter returned empty message content")
|
||||
|
||||
|
||||
def _clean(messages: list) -> list:
    """Filter out messages with empty content.

    String content is kept when it is non-blank after stripping whitespace.
    Non-string content (for example a list of content parts) is kept when it
    is truthy; previously such a message raised AttributeError on ``.strip()``.
    Messages with missing, ``None`` or otherwise falsy content are dropped.
    """
    kept = []
    for message in messages:
        content = message.get("content")
        if isinstance(content, str):
            if content.strip():
                kept.append(message)
        elif content:
            kept.append(message)
    return kept
|
||||
|
||||
|
||||
async def _post_once(model: str, messages: list, extra: dict | None = None) -> str:
    """Send one chat-completion request to OpenRouter and return the reply text.

    Args:
        model: Model identifier placed in the request payload.
        messages: Chat messages; entries with empty content are removed by
            ``_clean`` before sending.
        extra: Optional additional top-level payload fields.

    Returns:
        The assistant message content extracted by ``_parse_completion_body``.

    Raises:
        LLMError: If the API key is missing, the response is not JSON, the
            HTTP status is >= 400, or the body has no usable completion.
            Transport-level httpx exceptions are not converted and propagate
            unchanged.
    """
    if not OPENROUTER_KEY:
        raise LLMError("ROUTER_KEY is not set in environment")

    payload = {"model": model, "messages": _clean(messages), **(extra or {})}
    async with httpx.AsyncClient(timeout=90) as client:
        r = await client.post(OPENROUTER_URL, headers=HEADERS, json=payload)

    # No raise_for_status(): error bodies are decoded below so the API's own
    # error message ends up in the LLMError.
    try:
        data = r.json()
    except ValueError as e:
        raise LLMError(f"Non-JSON response (HTTP {r.status_code}): {r.text[:300]}") from e

    if r.status_code >= 400:
        # Raises LLMError with the API's own message when the body carries one;
        # if the body unexpectedly parses cleanly, fall through to a generic error.
        _parse_completion_body(data)
        raise LLMError(f"HTTP {r.status_code}: {data}")

    try:
        return _parse_completion_body(data)
    except LLMError:
        logger.warning(
            "OpenRouter completion failed model=%s status=%s body=%.500s",
            model,
            r.status_code,
            data,
        )
        raise
|
||||
|
||||
|
||||
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
    """POST completion; on failure retries once with LLM_FALLBACK_MODEL (usually CHAT_MODEL).

    Args:
        model: Primary model identifier to try first.
        messages: Chat messages forwarded to ``_post_once``.
        extra: Optional additional payload fields forwarded to ``_post_once``.

    Returns:
        The reply text from the primary model, or from the fallback model if
        the primary attempt raised ``LLMError``.

    Raises:
        LLMError: The primary error when no distinct fallback model is
            configured, or a combined error when the fallback also fails.
    """
    try:
        return await _post_once(model, messages, extra)
    except LLMError as primary_err:
        fallback = LLM_FALLBACK_MODEL
        # Retrying with the same model would only repeat the failure.
        if not fallback or fallback == model:
            raise
        logger.info(
            "LLM fallback: %s failed (%s) → retrying with %s",
            model,
            primary_err,
            fallback,
        )
        try:
            return await _post_once(fallback, messages, extra)
        except LLMError as fallback_err:
            raise LLMError(
                f"{primary_err} (fallback {fallback} also failed: {fallback_err})"
            ) from fallback_err
|
||||
|
||||
|
||||
async def send_message(messages: list) -> str:
    """Send messages to SYSTEM_MODEL (narrator, facts, SD prompt).

    Delegates to ``_post``, which automatically retries with
    LLM_FALLBACK_MODEL when the primary request fails.
    """
    return await _post(SYSTEM_MODEL, messages)
|
||||
|
||||
|
||||
async def send_message_with_model(messages: list, model: str) -> str:
    """Send messages to an explicitly named model (plot arc, narrator override; RPG_*, SD_*).

    Delegates to ``_post``, which automatically retries with
    LLM_FALLBACK_MODEL when the primary request fails.
    """
    return await _post(model, messages)
|
||||
|
||||
|
||||
@@ -73,10 +177,19 @@ async def stream_message(messages: list):
|
||||
return
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
content = chunk["choices"][0]["delta"].get("content", "")
|
||||
if chunk.get("error"):
|
||||
err = chunk["error"]
|
||||
msg = err.get("message", err) if isinstance(err, dict) else err
|
||||
raise LLMError(f"OpenRouter stream error: {msg}")
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
content = (choices[0].get("delta") or {}).get("content", "")
|
||||
if content:
|
||||
chunk_count += 1
|
||||
yield content
|
||||
except LLMError:
|
||||
raise
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user