# Source listing: 131 lines, 4.0 KiB, Python
from typing import Any
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.auth.deps import get_current_user
|
|
from app.db.base import get_db
|
|
from app.db.models import User
|
|
from app.db.models import ChatSession
|
|
from app.memory.extract import extract_after_turn
|
|
from app.memory.service import MemoryService
|
|
|
|
# Router collecting the memory/profile endpoints defined in this module;
# presumably mounted by the application via include_router — confirm there.
router: APIRouter = APIRouter()
|
|
|
|
|
|
class ProfileUpdate(BaseModel):
    """Request body for ``PUT /profile``.

    ``updates`` is handed unchanged to ``MemoryService.update_profile``; which
    keys are accepted is decided by that service, not by this model.
    """

    # Key/value pairs to apply; each instance gets its own fresh empty dict.
    updates: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class FactCreate(BaseModel):
    """Request body for ``POST /memory/facts``."""

    # Text of the fact to remember; required and must be non-empty.
    content: str = Field(..., min_length=1)
    # Free-form category label.
    category: str = Field(default="fact")
    # Importance score, constrained to 1..5 inclusive.
    importance: int = Field(3, ge=1, le=5)
    # Optional chat session to associate the fact with.
    session_id: int | None = Field(default=None)
|
|
|
|
|
|
class SessionSummaryUpdate(BaseModel):
    """Request body for ``PUT /memory/sessions/{session_id}/summary``."""

    # Replacement summary text; required and must be non-empty.
    summary: str = Field(min_length=1)
    # Presumably the number of messages the summary covers. A count can never
    # be negative, so reject such values at validation time (HTTP 422)
    # instead of passing them on to the service.
    message_count: int = Field(default=0, ge=0)
|
|
|
|
|
|
class ExtractRequest(BaseModel):
    """Request body for ``POST /memory/extract``: one conversation turn."""

    # Chat session the turn belongs to; the endpoint checks caller ownership.
    session_id: int
    # The user's message for this turn; required and must be non-empty.
    user_text: str = Field(..., min_length=1)
    # The assistant's reply for this turn; may be empty.
    assistant_text: str = Field(default="")
    # Forwarded as ``force=`` to ``extract_after_turn``.
    force: bool = Field(default=False)
|
|
|
|
|
|
@router.get("/memory")
def get_memory_snapshot(
    session_id: int | None = None,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, Any]:
    """Return the caller's memory snapshot from ``MemoryService.snapshot``.

    ``session_id`` is an optional query parameter forwarded unchanged.
    """
    service = MemoryService(db, user.id)
    snapshot = service.snapshot(session_id)
    return snapshot
|
|
|
|
|
|
@router.get("/profile")
def get_profile(
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, Any]:
    """Return the authenticated user's profile as produced by ``MemoryService``."""
    memory = MemoryService(db, user.id)
    return memory.get_profile()
|
|
|
|
|
|
@router.put("/profile")
def update_profile(
    payload: ProfileUpdate,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, Any]:
    """Apply ``payload.updates`` to the caller's profile and return the result.

    A ``ValueError`` raised by the service (or while constructing it) is
    reported to the client as HTTP 400 with the error text as ``detail``.
    """
    try:
        updated = MemoryService(db, user.id).update_profile(payload.updates)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err
    return updated
|
|
|
|
|
|
@router.get("/memory/facts")
def list_facts(
    query: str | None = None,
    category: str | None = None,
    limit: int = 30,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> list[dict[str, Any]]:
    """List the caller's stored memories, optionally filtered.

    ``query`` and ``category`` are forwarded unchanged to
    ``MemoryService.recall_memories``. ``limit`` is a client-supplied query
    parameter and is clamped to ``1..max_limit`` before use.
    """
    # Upper bound on rows a single request may ask for.
    max_limit = 100
    # ``limit`` comes straight from the query string: without clamping, a
    # client could send 0 or a negative value (meaning differs by DB backend)
    # or an arbitrarily large value that dumps the whole table.
    bounded_limit = max(1, min(limit, max_limit))
    return MemoryService(db, user.id).recall_memories(
        query=query,
        category=category,
        limit=bounded_limit,
    )
|
|
|
|
|
|
@router.post("/memory/facts")
def create_fact(
    payload: FactCreate,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, Any]:
    """Store a new fact for the caller, tagged with ``source="api"``.

    If ``payload.session_id`` is given, that chat session must exist and belong
    to the caller, otherwise HTTP 404 is returned. This is the same ownership
    check ``POST /memory/extract`` performs. A ``ValueError`` from the service
    is reported as HTTP 400.
    """
    if payload.session_id is not None:
        # Without this check a caller could attach facts to another user's
        # session just by guessing its id.
        session = db.get(ChatSession, payload.session_id)
        if not session or session.user_id != user.id:
            raise HTTPException(status_code=404, detail="Session not found")
    try:
        return MemoryService(db, user.id).remember_fact(
            payload.content,
            category=payload.category,
            session_id=payload.session_id,
            importance=payload.importance,
            source="api",
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
|
|
@router.delete("/memory/facts/{memory_id}")
def forget_fact(
    memory_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, Any]:
    """Delete one of the caller's memories by id.

    A ``ValueError`` from the service is reported as HTTP 404. Presumably this
    means the memory is missing or not the caller's; confirm in MemoryService.
    """
    try:
        outcome = MemoryService(db, user.id).forget_memory(memory_id)
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err)) from err
    return outcome
|
|
|
|
|
|
@router.post("/memory/extract")
async def extract_memories(
    payload: ExtractRequest,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict:
    """Run memory extraction over one user/assistant turn.

    Responds 404 unless ``payload.session_id`` names a chat session owned by
    the caller. Otherwise it returns whatever ``extract_after_turn`` produces.
    """
    # NOTE(review): ``db`` is a synchronous SQLAlchemy Session used inside an
    # ``async def``, so ``db.get`` blocks the event loop while it runs.
    chat = db.get(ChatSession, payload.session_id)
    if not (chat and chat.user_id == user.id):
        raise HTTPException(status_code=404, detail="Session not found")
    result = await extract_after_turn(
        db,
        payload.session_id,
        payload.user_text,
        payload.assistant_text,
        user_id=user.id,
        force=payload.force,
    )
    return result
|
|
|
|
|
|
@router.put("/memory/sessions/{session_id}/summary")
def update_session_summary(
    session_id: int,
    payload: SessionSummaryUpdate,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, Any]:
    """Replace the stored summary for one of the caller's chat sessions.

    This handler does not check whether ``session_id`` exists or who owns it.
    Presumably ``MemoryService.update_session_summary`` enforces that; confirm.
    Any ``ValueError`` from the service becomes HTTP 400.
    """
    try:
        updated = MemoryService(db, user.id).update_session_summary(
            session_id,
            payload.summary,
            message_count=payload.message_count,
        )
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err
    return updated
|