diff --git a/backend/app/config.py b/backend/app/config.py index 093d7e5..7f8d6a8 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -9,12 +9,24 @@ DEPRECATED_VISION_MODELS: dict[str, str] = { "google/gemini-2.0-flash-lite": "google/gemini-2.5-flash-lite", } +DEPRECATED_EXTRACT_MODELS: dict[str, str] = { + "google/gemini-2.0-flash-001": "google/gemini-2.5-flash-lite", + "google/gemini-2.0-flash": "google/gemini-2.5-flash-lite", + "google/gemini-2.0-flash-lite-001": "google/gemini-2.5-flash-lite", + "google/gemini-2.0-flash-lite": "google/gemini-2.5-flash-lite", +} + def resolve_vision_model(model: str) -> str: stripped = model.strip() return DEPRECATED_VISION_MODELS.get(stripped, stripped) +def resolve_extract_model(model: str) -> str: + stripped = model.strip() + return DEPRECATED_EXTRACT_MODELS.get(stripped, stripped) + + class Settings(BaseSettings): model_config = SettingsConfigDict( env_file=(".env", "../.env"), diff --git a/backend/app/llm/client.py b/backend/app/llm/client.py index 047d163..27d5dff 100644 --- a/backend/app/llm/client.py +++ b/backend/app/llm/client.py @@ -19,6 +19,9 @@ class LLMClient: base_url=settings.openrouter_base_url, ) + async def aclose(self) -> None: + await self.client.close() + def _runtime(self) -> tuple[str, str, str]: from app.db.base import SessionLocal from app.settings.service import SettingsService diff --git a/backend/app/memory/extract.py b/backend/app/memory/extract.py index f07a405..703d280 100644 --- a/backend/app/memory/extract.py +++ b/backend/app/memory/extract.py @@ -5,10 +5,11 @@ from typing import Any from sqlalchemy.orm import Session -from app.config import get_settings +from app.db.base import SessionLocal from app.llm.client import LLMClient from app.memory.service import MemoryService from app.projects.structuring import strip_markdown_json +from app.settings.service import SettingsService logger = logging.getLogger(__name__) @@ -62,28 +63,34 @@ async def _call_extractor( *[f"- {f.get('content')}" for f in facts[:30]], ] - settings = get_settings() - extract_model = settings.memory_extract_model.strip() or None + db = SessionLocal() + try: + extract_model = str(SettingsService(db).get_effective("memory_extract_model")).strip() or None + finally: + db.close() llm = LLMClient() - result = await llm.complete( - [ - {"role": "system", "content": EXTRACTION_PROMPT}, - { - "role": "user", - "content": ( - "\n".join(known) - + "\n\n---\nДиалог:\nПользователь: " - + user_text - + "\nАссистент: " - + assistant_text[:1500] - ), - }, - ], - temperature=0.2, - model=extract_model, - for_extraction=True, - ) + try: + result = await llm.complete( + [ + {"role": "system", "content": EXTRACTION_PROMPT}, + { + "role": "user", + "content": ( + "\n".join(known) + + "\n\n---\nДиалог:\nПользователь: " + + user_text + + "\nАссистент: " + + assistant_text[:1500] + ), + }, + ], + temperature=0.2, + model=extract_model, + for_extraction=True, + ) + finally: + await llm.aclose() raw = strip_markdown_json(result.get("content") or "") if not raw: return {"facts": [], "profile": {}} diff --git a/backend/app/rag/embeddings.py b/backend/app/rag/embeddings.py index 2e057e8..f281f01 100644 --- a/backend/app/rag/embeddings.py +++ b/backend/app/rag/embeddings.py @@ -7,4 +7,7 @@ async def embed_texts(texts: list[str]) -> list[list[float]]: if not texts: return [] client = LLMClient() - return await client.embed(texts) + try: + return await client.embed(texts) + finally: + await client.aclose() diff --git a/backend/app/settings/service.py b/backend/app/settings/service.py index b5ca24e..9505404 100644 --- a/backend/app/settings/service.py +++ b/backend/app/settings/service.py @@ -7,7 +7,7 @@ from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session -from app.config import Settings, get_settings, resolve_vision_model +from app.config import Settings, get_settings, resolve_extract_model, resolve_vision_model from app.db.models import AssistantState SETTING_KEYS = ( @@ -69,6 +69,8 @@ class SettingsService: return self._default_for(key) if key == "openrouter_vision_model": return resolve_vision_model(raw.strip()) + if key == "memory_extract_model": + return resolve_extract_model(raw.strip()) return raw def snapshot(self) -> dict[str, Any]: