105 lines
3.5 KiB
Python
105 lines
3.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.config import Settings, get_settings, resolve_extract_model, resolve_vision_model
|
|
from app.db.models import AssistantState
|
|
|
|
# Keys that may be overridden at runtime through AssistantState rows.
# SettingsService.patch() ignores any key not listed here, and
# SettingsService.snapshot() reports them in this order.
SETTING_KEYS = (
    # OpenRouter model / request options.
    "openrouter_model",
    "memory_extract_model",
    "openrouter_vision_model",
    "openrouter_reasoning_effort",
    # RAG options.
    "rag_enabled",
    "rag_top_k",
)
|
|
|
|
|
|
class SettingsService:
    """Runtime-overridable settings backed by ``AssistantState`` rows.

    Each overridable key (see ``SETTING_KEYS``) may have a row whose
    ``value`` column holds the override as text. When no non-blank row
    exists, the value from ``get_settings()`` is used instead.
    """

    # Lower-cased string tokens treated as "true" for boolean settings.
    _TRUTHY_TOKENS = frozenset({"1", "true", "yes", "on"})

    def __init__(self, db: Session):
        self.db = db
        self._defaults = get_settings()

    def _get_row(self, key: str) -> AssistantState | None:
        """Return the stored row for ``key`` (primary-key lookup), or None."""
        return self.db.get(AssistantState, key)

    def get_raw(self, key: str) -> str | None:
        """Return the stripped stored override for ``key``.

        Returns None when there is no row or its value is empty/whitespace.
        """
        row = self._get_row(key)
        if not row or not (row.value or "").strip():
            return None
        return row.value.strip()

    def _upsert(self, key: str, value: str) -> None:
        """Insert or update the row for ``key`` without committing."""
        row = self._get_row(key)
        if not row:
            # NOTE(review): updated_at is only set explicitly for existing rows;
            # presumably new rows rely on a column default — confirm.
            self.db.add(AssistantState(key=key, value=value))
        else:
            row.value = value
            row.updated_at = datetime.now(timezone.utc)

    def set_raw(self, key: str, value: str) -> None:
        """Store ``value`` verbatim as the override for ``key`` and commit."""
        self._upsert(key, value)
        self.db.commit()

    def _default_for(self, key: str) -> Any:
        """Return the configured fallback value for ``key``.

        Raises:
            KeyError: if ``key`` is not one of the known overridable keys.
        """
        defaults: Settings = self._defaults
        mapping = {
            "openrouter_model": defaults.openrouter_model,
            # An unset extract model falls back to the main chat model.
            "memory_extract_model": defaults.memory_extract_model or defaults.openrouter_model,
            "openrouter_vision_model": defaults.openrouter_vision_model,
            "openrouter_reasoning_effort": defaults.openrouter_reasoning_effort,
            "rag_enabled": defaults.rag_enabled,
            "rag_top_k": defaults.rag_top_k,
        }
        return mapping[key]

    @classmethod
    def _is_truthy(cls, value: Any) -> bool:
        """Interpret ``value`` as a boolean setting.

        Strings are matched (after stripping and lower-casing) against
        ``_TRUTHY_TOKENS``, so ``"false"``, ``"0"`` or ``"off"`` are False.
        Any other type falls back to ``bool(value)``.
        """
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUTHY_TOKENS
        return bool(value)

    @classmethod
    def _normalize_for_storage(cls, key: str, value: Any) -> str | None:
        """Convert a non-None patch value for ``key`` to its stored text form.

        Returns None when the value should be ignored: a blank string for a
        free-text setting.

        Raises:
            ValueError, TypeError: propagated from ``int()`` when a
                ``rag_top_k`` value is not integer-like.
        """
        if key == "rag_enabled":
            return "true" if cls._is_truthy(value) else "false"
        if key == "rag_top_k":
            return str(int(value))
        text = str(value).strip()
        return text or None

    def get_effective(self, key: str) -> Any:
        """Return the override for ``key`` parsed to its natural type, else the default.

        - ``rag_enabled``: parsed with ``_is_truthy``.
        - ``rag_top_k``: parsed as int and clamped to [1, 50]. Unparseable
          text falls back to the default, which is returned unclamped.
        - ``openrouter_vision_model`` / ``memory_extract_model``: passed
          through ``resolve_vision_model`` / ``resolve_extract_model``.
        - any other key: the stored text.
        """
        raw = self.get_raw(key)  # already stripped, or None if unset/blank
        if raw is None:
            return self._default_for(key)
        if key == "rag_enabled":
            return self._is_truthy(raw)
        if key == "rag_top_k":
            try:
                return max(1, min(50, int(raw)))
            except ValueError:
                return self._default_for(key)
        if key == "openrouter_vision_model":
            return resolve_vision_model(raw)
        if key == "memory_extract_model":
            return resolve_extract_model(raw)
        return raw

    def snapshot(self) -> dict[str, Any]:
        """Return every effective overridable setting plus some fixed config values."""
        data: dict[str, Any] = {key: self.get_effective(key) for key in SETTING_KEYS}
        # These come straight from configuration and cannot be overridden here.
        data["embedding_model"] = self._defaults.embedding_model
        data["memory_facts_in_context"] = self._defaults.memory_facts_in_context
        data["qdrant_url"] = self._defaults.qdrant_url
        return data

    def patch(self, updates: dict[str, Any]) -> dict[str, Any]:
        """Apply ``updates`` to the stored overrides and return a fresh snapshot.

        Keys not in ``SETTING_KEYS`` are ignored. A ``None`` value deletes the
        override, so the default applies again. A blank string for a
        free-text key is ignored.

        Every value is normalized before the database is touched, and all
        changes are committed together. An invalid value therefore aborts
        the whole patch instead of leaving earlier keys committed.

        Raises:
            ValueError, TypeError: when a ``rag_top_k`` value is not integer-like.
        """
        # Phase 1: validate/normalize. A None stored value means "delete".
        planned: list[tuple[str, str | None]] = []
        for key, value in updates.items():
            if key not in SETTING_KEYS:
                continue
            if value is None:
                planned.append((key, None))
                continue
            stored = self._normalize_for_storage(key, value)
            if stored is not None:
                planned.append((key, stored))

        # Phase 2: apply all changes, then commit once.
        for key, stored in planned:
            if stored is None:
                row = self._get_row(key)
                if row:
                    self.db.delete(row)
            else:
                self._upsert(key, stored)
        if planned:
            self.db.commit()
        return self.snapshot()
|