fix migration
This commit is contained in:
@@ -9,12 +9,24 @@ DEPRECATED_VISION_MODELS: dict[str, str] = {
|
||||
"google/gemini-2.0-flash-lite": "google/gemini-2.5-flash-lite",
|
||||
}
|
||||
|
||||
DEPRECATED_EXTRACT_MODELS: dict[str, str] = {
|
||||
"google/gemini-2.0-flash-001": "google/gemini-2.5-flash-lite",
|
||||
"google/gemini-2.0-flash": "google/gemini-2.5-flash-lite",
|
||||
"google/gemini-2.0-flash-lite-001": "google/gemini-2.5-flash-lite",
|
||||
"google/gemini-2.0-flash-lite": "google/gemini-2.5-flash-lite",
|
||||
}
|
||||
|
||||
|
||||
def resolve_vision_model(model: str) -> str:
|
||||
stripped = model.strip()
|
||||
return DEPRECATED_VISION_MODELS.get(stripped, stripped)
|
||||
|
||||
|
||||
def resolve_extract_model(model: str) -> str:
|
||||
stripped = model.strip()
|
||||
return DEPRECATED_EXTRACT_MODELS.get(stripped, stripped)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=(".env", "../.env"),
|
||||
|
||||
@@ -19,6 +19,9 @@ class LLMClient:
|
||||
base_url=settings.openrouter_base_url,
|
||||
)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.client.close()
|
||||
|
||||
def _runtime(self) -> tuple[str, str, str]:
|
||||
from app.db.base import SessionLocal
|
||||
from app.settings.service import SettingsService
|
||||
|
||||
@@ -5,10 +5,11 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import get_settings
|
||||
from app.db.base import SessionLocal
|
||||
from app.llm.client import LLMClient
|
||||
from app.memory.service import MemoryService
|
||||
from app.projects.structuring import strip_markdown_json
|
||||
from app.settings.service import SettingsService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,28 +63,34 @@ async def _call_extractor(
|
||||
*[f"- {f.get('content')}" for f in facts[:30]],
|
||||
]
|
||||
|
||||
settings = get_settings()
|
||||
extract_model = settings.memory_extract_model.strip() or None
|
||||
db = SessionLocal()
|
||||
try:
|
||||
extract_model = str(SettingsService(db).get_effective("memory_extract_model")).strip() or None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
llm = LLMClient()
|
||||
result = await llm.complete(
|
||||
[
|
||||
{"role": "system", "content": EXTRACTION_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"\n".join(known)
|
||||
+ "\n\n---\nДиалог:\nПользователь: "
|
||||
+ user_text
|
||||
+ "\nАссистент: "
|
||||
+ assistant_text[:1500]
|
||||
),
|
||||
},
|
||||
],
|
||||
temperature=0.2,
|
||||
model=extract_model,
|
||||
for_extraction=True,
|
||||
)
|
||||
try:
|
||||
result = await llm.complete(
|
||||
[
|
||||
{"role": "system", "content": EXTRACTION_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"\n".join(known)
|
||||
+ "\n\n---\nДиалог:\nПользователь: "
|
||||
+ user_text
|
||||
+ "\nАссистент: "
|
||||
+ assistant_text[:1500]
|
||||
),
|
||||
},
|
||||
],
|
||||
temperature=0.2,
|
||||
model=extract_model,
|
||||
for_extraction=True,
|
||||
)
|
||||
finally:
|
||||
await llm.aclose()
|
||||
raw = strip_markdown_json(result.get("content") or "")
|
||||
if not raw:
|
||||
return {"facts": [], "profile": {}}
|
||||
|
||||
@@ -7,4 +7,7 @@ async def embed_texts(texts: list[str]) -> list[list[float]]:
|
||||
if not texts:
|
||||
return []
|
||||
client = LLMClient()
|
||||
return await client.embed(texts)
|
||||
try:
|
||||
return await client.embed(texts)
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import Settings, get_settings, resolve_vision_model
|
||||
from app.config import Settings, get_settings, resolve_extract_model, resolve_vision_model
|
||||
from app.db.models import AssistantState
|
||||
|
||||
SETTING_KEYS = (
|
||||
@@ -69,6 +69,8 @@ class SettingsService:
|
||||
return self._default_for(key)
|
||||
if key == "openrouter_vision_model":
|
||||
return resolve_vision_model(raw.strip())
|
||||
if key == "memory_extract_model":
|
||||
return resolve_extract_model(raw.strip())
|
||||
return raw
|
||||
|
||||
def snapshot(self) -> dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user