87 lines
3.1 KiB
Python
87 lines
3.1 KiB
Python
import httpx
|
|
import json
|
|
import logging
|
|
import os
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()

logger = logging.getLogger(__name__)

# OpenRouter API key, read from the ROUTER_KEY environment variable
# (load_dotenv() above also pulls it from a .env file if present).
OPENROUTER_KEY = os.getenv("ROUTER_KEY")
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"

# CHAT_MODEL is used for streamed roleplay dialogue; SYSTEM_MODEL for the
# narrator / facts / SD-prompt calls. Both can be overridden via env vars.
CHAT_MODEL = os.getenv("CHAT_MODEL", "mistralai/mistral-nemo")
SYSTEM_MODEL = os.getenv("SYSTEM_MODEL", "google/gemini-2.5-flash")

if not OPENROUTER_KEY:
    # Without this, every request is sent as "Bearer None" and fails later
    # with an authentication error that hides the real cause.
    logger.warning("ROUTER_KEY is not set; OpenRouter requests will likely fail authentication")

# Headers sent with every OpenRouter request. The key is interpolated once at
# import time, so changing ROUTER_KEY afterwards has no effect here.
HEADERS = {
    "Authorization": f"Bearer {OPENROUTER_KEY}",
    "Content-Type": "application/json",
    "HTTP-Referer": "http://localhost:8000",
}
|
|
|
|
|
|
def _clean(messages: list) -> list:
|
|
"""Filter out messages with empty content."""
|
|
return [m for m in messages if (m.get("content") or "").strip()]
|
|
|
|
|
|
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
    """Send a non-streaming chat completion to OpenRouter and return the reply text.

    Args:
        model: Model identifier placed in the request payload.
        messages: Chat messages; entries with empty content are removed by
            ``_clean`` before sending.
        extra: Optional additional top-level payload fields. They are merged
            last, so they can also override ``model`` or ``messages``.

    Returns:
        ``choices[0].message.content`` from the response, or ``""`` when that
        content is ``null``.

    Raises:
        httpx.HTTPStatusError: On a non-2xx response (the body is logged first).
        httpx.HTTPError: On transport failures or the 90 s timeout.
        KeyError / IndexError: If the response JSON has no usable ``choices``.
    """
    payload = {"model": model, "messages": _clean(messages), **(extra or {})}
    async with httpx.AsyncClient(timeout=90) as client:
        r = await client.post(OPENROUTER_URL, headers=HEADERS, json=payload)
        try:
            r.raise_for_status()
        except httpx.HTTPStatusError:
            # raise_for_status() only reports the status line; the body is
            # where the provider's explanation usually is.
            logger.error("OpenRouter %s HTTP %d: %.500s", model, r.status_code, r.text)
            raise
        data = r.json()
    if "choices" not in data:
        # Log what came back instead of failing with a bare KeyError below.
        logger.error("OpenRouter %s returned no choices: %.500s", model, data)
    # content may be null; callers are promised a str.
    return data["choices"][0]["message"]["content"] or ""
|
|
|
|
|
|
async def send_message(messages: list) -> str:
    """Ask the system model (narrator, facts, SD prompt) and return its reply text."""
    reply = await _post(model=SYSTEM_MODEL, messages=messages)
    return reply
|
|
|
|
|
|
async def send_message_with_model(messages: list, model: str) -> str:
    """Ask an explicitly chosen *model* (plot arc, narrator override) and return its reply text."""
    reply = await _post(model=model, messages=messages)
    return reply
|
|
|
|
|
|
async def stream_message(messages: list):
    """Stream a chat-model reply (roleplay dialogue), yielding text fragments.

    Posts the cleaned ``messages`` to ``CHAT_MODEL`` with ``"stream": True``
    and parses the server-sent-events body line by line. Every non-empty
    ``choices[0].delta.content`` value is yielded as soon as its line is
    complete. The generator returns on ``data: [DONE]`` or when the server
    closes the stream.

    Yields:
        The next piece of generated text.

    Raises:
        httpx.HTTPStatusError: On a non-2xx response (its body is logged first).
        httpx.HTTPError: On connection problems or timeouts while streaming.
    """
    import codecs  # incremental decoder for UTF-8 split across byte chunks

    payload = {
        "model": CHAT_MODEL,
        "messages": _clean(messages),
        "stream": True,
    }
    timeout = httpx.Timeout(connect=10, read=120, write=10, pool=5)
    chunk_count = 0

    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            async with client.stream(
                "POST", OPENROUTER_URL, headers=HEADERS, json=payload
            ) as response:
                try:
                    response.raise_for_status()
                except httpx.HTTPStatusError:
                    # A streamed body has not been read yet; load it so the
                    # provider's error text reaches the log.
                    await response.aread()
                    logger.error(
                        "stream_message HTTP %d: %.500s",
                        response.status_code,
                        response.text,
                    )
                    raise

                # Decoding each byte chunk on its own would turn multi-byte
                # UTF-8 characters split across chunk boundaries into U+FFFD.
                decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
                buf = ""
                first_logged = False
                async for raw in response.aiter_bytes():
                    text = decoder.decode(raw)
                    if text and not first_logged:
                        logger.info("stream first bytes: %.200s", text)
                        first_logged = True
                    buf += text
                    # Complete lines are processed now; the trailing partial
                    # line stays in buf for the next chunk. An unterminated
                    # line at EOF is dropped, as SSE discards incomplete events.
                    *lines, buf = buf.split("\n")
                    for line in lines:
                        line = line.rstrip("\r")
                        # Only "data:" fields carry payloads; SSE comment lines
                        # (starting with ":") and blank separators are skipped.
                        if not line.startswith("data:"):
                            continue
                        # SSE allows one optional space after the colon.
                        data = line[5:].removeprefix(" ")
                        if data == "[DONE]":
                            return
                        try:
                            chunk = json.loads(data)
                            if chunk.get("error"):
                                # In-band error object: the HTTP status was
                                # already 2xx, so this is the only signal.
                                logger.error("stream_message provider error: %s", chunk["error"])
                            content = chunk["choices"][0]["delta"].get("content")
                        except (ValueError, KeyError, IndexError, TypeError, AttributeError):
                            logger.debug("stream_message skipped malformed event: %.200s", data)
                            continue
                        if content:
                            chunk_count += 1
                            yield content
        except Exception as e:
            logger.error("stream_message error after %d chunks: %s", chunk_count, e)
            raise
        finally:
            logger.info("stream_message finished: %d chunks", chunk_count)
|