Files
Home_assistant/backend/app/api/routes/chat.py
T
2026-06-13 20:20:56 +00:00

159 lines
5.6 KiB
Python

import asyncio
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.api.chat_schemas import GenerationStatusOut, MessagesPageOut
from app.api.schemas import (
MessageCreate,
SessionCreate,
SessionDetailOut,
SessionOut,
)
from app.chat.generation import (
GenerationBusyError,
get_active_handle,
is_generation_active,
start_generation,
subscribe_generation,
)
from app.chat.service import ChatService
from app.auth.deps import get_current_user
from app.db.base import get_db
from app.db.models import User
# Router collecting the chat session / message endpoints defined in this module.
router = APIRouter()
@router.post("/sessions", response_model=SessionOut)
def create_session(
    payload: SessionCreate,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> SessionOut:
    """Create a new chat session owned by the current user.

    The title is taken from the request payload; the result is whatever
    ``ChatService.create_session`` returns, serialized as ``SessionOut``.
    """
    return ChatService(db, user.id).create_session(title=payload.title)
@router.get("/sessions", response_model=list[SessionOut])
def list_sessions(
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> list[SessionOut]:
    """Return the chat sessions of the current user.

    Selection and ordering are delegated to ``ChatService.list_sessions``.
    """
    chat = ChatService(db, user.id)
    sessions = chat.list_sessions()
    return sessions
@router.get("/sessions/{session_id}", response_model=SessionDetailOut)
def get_session(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> SessionDetailOut:
    """Return the details of one session of the current user.

    Raises:
        HTTPException: 404 when ``ChatService.get_session`` finds nothing.
    """
    found = ChatService(db, user.id).get_session(session_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Session not found")
@router.get("/sessions/{session_id}/messages", response_model=MessagesPageOut)
def list_messages(
    session_id: int,
    limit: int = 30,
    before_id: int | None = None,
    after_id: int | None = None,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> MessagesPageOut:
    """Return one page of messages from a session.

    ``before_id`` / ``after_id`` are forwarded to ``ChatService.list_messages``
    as cursors; ``limit`` is the requested page size.

    Raises:
        HTTPException: 404 when the session is not found for this user.
    """
    chat = ChatService(db, user.id)
    if not chat.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    # Out-of-range page sizes are clamped into [1, 100] rather than rejected.
    page_size = max(1, min(limit, 100))
    page, more = chat.list_messages(
        session_id,
        limit=page_size,
        before_id=before_id,
        after_id=after_id,
    )
    return MessagesPageOut(messages=page, has_more=more)
@router.get("/sessions/{session_id}/generation", response_model=GenerationStatusOut)
def generation_status(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> GenerationStatusOut:
    """Report whether a generation is currently active for the session.

    Raises:
        HTTPException: 404 when the session is not found for this user.
    """
    if ChatService(db, user.id).get_session(session_id):
        active = is_generation_active(session_id)
        return GenerationStatusOut(active=active)
    raise HTTPException(status_code=404, detail="Session not found")
@router.get("/sessions/{session_id}/generation/stream")
async def generation_stream(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Subscribe to the session's currently active generation as an SSE stream.

    The stream relays, unchanged, the chunks produced by
    ``subscribe_generation`` for the handle returned by ``get_active_handle``.

    Raises:
        HTTPException: 404 when the session is not found for this user, or
            when no generation is active for it.
    """
    service = ChatService(db, user.id)
    # ChatService works on a synchronous SQLAlchemy Session. Calling it directly
    # inside this ``async def`` would block the event loop, and with it every
    # other in-flight stream, so the lookup runs in a worker thread instead.
    session = await asyncio.to_thread(service.get_session, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    handle = get_active_handle(session_id)
    if not handle:
        raise HTTPException(status_code=404, detail="No active generation")

    async def event_stream():
        # Relay generation chunks to the client as-is.
        async for chunk in subscribe_generation(handle):
            yield chunk

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )
@router.delete("/sessions/{session_id}")
def delete_session(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, bool]:
    """Delete one of the current user's sessions.

    Returns ``{"ok": True}`` on success.

    Raises:
        HTTPException: 404 when ``ChatService.delete_session`` reports failure.
    """
    deleted = ChatService(db, user.id).delete_session(session_id)
    if deleted:
        return {"ok": True}
    raise HTTPException(status_code=404, detail="Session not found")
@router.post("/sessions/{session_id}/messages")
async def send_message(
    session_id: int,
    payload: MessageCreate,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Store a user message, start a generation for it, and stream the output as SSE.

    Raises:
        HTTPException: 404 when the session is not found for this user;
            409 when a generation is already in progress for the session.
    """
    service = ChatService(db, user.id)
    # ChatService uses a synchronous SQLAlchemy Session; offload its calls so
    # they do not block the event loop serving other streams.
    if not await asyncio.to_thread(service.get_session, session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    # Cheap early rejection. start_generation can still raise
    # GenerationBusyError if another request gets in first (handled below).
    if is_generation_active(session_id):
        raise HTTPException(status_code=409, detail="Generation already in progress")
    # Persist the user message before streaming: otherwise, if the SSE
    # connection drops, the message never reaches the database.
    await asyncio.to_thread(service.save_user_message, session_id, payload.content)
    try:
        handle = await start_generation(session_id, user.id, payload.content)
    except GenerationBusyError:
        # NOTE(review): the user message saved above stays in the DB even though
        # this request is rejected with 409. Consider removing it here if
        # ChatService offers a way to; confirm the desired behavior.
        raise HTTPException(status_code=409, detail="Generation already in progress") from None

    async def event_stream():
        # Cancellation (e.g. on client disconnect) propagates unchanged; the
        # generation itself is driven by `handle`, not by this subscriber.
        async for chunk in subscribe_generation(handle):
            yield chunk

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )
@router.get("/sessions/{session_id}/context-preview")
def context_preview(
    session_id: int,
    query: str | None = None,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict:
    """Return ChatService's context preview for a session, optionally for ``query``.

    Raises:
        HTTPException: 404 when the session is not found for this user.
    """
    service = ChatService(db, user.id)
    # Like every other per-session route in this module, resolve the session
    # for this user first. An unknown id (and, presumably, one owned by another
    # user) then yields a 404 instead of reaching context_preview.
    if not service.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    return service.context_preview(session_id, query=query)