fix migration
This commit is contained in:
@@ -9,12 +9,24 @@ DEPRECATED_VISION_MODELS: dict[str, str] = {
|
|||||||
"google/gemini-2.0-flash-lite": "google/gemini-2.5-flash-lite",
|
"google/gemini-2.0-flash-lite": "google/gemini-2.5-flash-lite",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DEPRECATED_EXTRACT_MODELS: dict[str, str] = {
|
||||||
|
"google/gemini-2.0-flash-001": "google/gemini-2.5-flash-lite",
|
||||||
|
"google/gemini-2.0-flash": "google/gemini-2.5-flash-lite",
|
||||||
|
"google/gemini-2.0-flash-lite-001": "google/gemini-2.5-flash-lite",
|
||||||
|
"google/gemini-2.0-flash-lite": "google/gemini-2.5-flash-lite",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def resolve_vision_model(model: str) -> str:
|
def resolve_vision_model(model: str) -> str:
|
||||||
stripped = model.strip()
|
stripped = model.strip()
|
||||||
return DEPRECATED_VISION_MODELS.get(stripped, stripped)
|
return DEPRECATED_VISION_MODELS.get(stripped, stripped)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_extract_model(model: str) -> str:
|
||||||
|
stripped = model.strip()
|
||||||
|
return DEPRECATED_EXTRACT_MODELS.get(stripped, stripped)
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=(".env", "../.env"),
|
env_file=(".env", "../.env"),
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ class LLMClient:
|
|||||||
base_url=settings.openrouter_base_url,
|
base_url=settings.openrouter_base_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self.client.close()
|
||||||
|
|
||||||
def _runtime(self) -> tuple[str, str, str]:
|
def _runtime(self) -> tuple[str, str, str]:
|
||||||
from app.db.base import SessionLocal
|
from app.db.base import SessionLocal
|
||||||
from app.settings.service import SettingsService
|
from app.settings.service import SettingsService
|
||||||
|
|||||||
@@ -5,10 +5,11 @@ from typing import Any
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import get_settings
|
from app.db.base import SessionLocal
|
||||||
from app.llm.client import LLMClient
|
from app.llm.client import LLMClient
|
||||||
from app.memory.service import MemoryService
|
from app.memory.service import MemoryService
|
||||||
from app.projects.structuring import strip_markdown_json
|
from app.projects.structuring import strip_markdown_json
|
||||||
|
from app.settings.service import SettingsService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -62,10 +63,14 @@ async def _call_extractor(
|
|||||||
*[f"- {f.get('content')}" for f in facts[:30]],
|
*[f"- {f.get('content')}" for f in facts[:30]],
|
||||||
]
|
]
|
||||||
|
|
||||||
settings = get_settings()
|
db = SessionLocal()
|
||||||
extract_model = settings.memory_extract_model.strip() or None
|
try:
|
||||||
|
extract_model = str(SettingsService(db).get_effective("memory_extract_model")).strip() or None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
llm = LLMClient()
|
llm = LLMClient()
|
||||||
|
try:
|
||||||
result = await llm.complete(
|
result = await llm.complete(
|
||||||
[
|
[
|
||||||
{"role": "system", "content": EXTRACTION_PROMPT},
|
{"role": "system", "content": EXTRACTION_PROMPT},
|
||||||
@@ -84,6 +89,8 @@ async def _call_extractor(
|
|||||||
model=extract_model,
|
model=extract_model,
|
||||||
for_extraction=True,
|
for_extraction=True,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
await llm.aclose()
|
||||||
raw = strip_markdown_json(result.get("content") or "")
|
raw = strip_markdown_json(result.get("content") or "")
|
||||||
if not raw:
|
if not raw:
|
||||||
return {"facts": [], "profile": {}}
|
return {"facts": [], "profile": {}}
|
||||||
|
|||||||
@@ -7,4 +7,7 @@ async def embed_texts(texts: list[str]) -> list[list[float]]:
|
|||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
client = LLMClient()
|
client = LLMClient()
|
||||||
|
try:
|
||||||
return await client.embed(texts)
|
return await client.embed(texts)
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Any
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import Settings, get_settings, resolve_vision_model
|
from app.config import Settings, get_settings, resolve_extract_model, resolve_vision_model
|
||||||
from app.db.models import AssistantState
|
from app.db.models import AssistantState
|
||||||
|
|
||||||
SETTING_KEYS = (
|
SETTING_KEYS = (
|
||||||
@@ -69,6 +69,8 @@ class SettingsService:
|
|||||||
return self._default_for(key)
|
return self._default_for(key)
|
||||||
if key == "openrouter_vision_model":
|
if key == "openrouter_vision_model":
|
||||||
return resolve_vision_model(raw.strip())
|
return resolve_vision_model(raw.strip())
|
||||||
|
if key == "memory_extract_model":
|
||||||
|
return resolve_extract_model(raw.strip())
|
||||||
return raw
|
return raw
|
||||||
|
|
||||||
def snapshot(self) -> dict[str, Any]:
|
def snapshot(self) -> dict[str, Any]:
|
||||||
|
|||||||
Reference in New Issue
Block a user