71 lines
2.3 KiB
Plaintext
71 lines
2.3 KiB
Plaintext
from fastapi import APIRouter, Depends, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.api.schemas import MessageCreate, MessageOut, SessionCreate, SessionDetailOut, SessionOut
|
|
from app.chat.service import ChatService
|
|
from app.db.base import get_db
|
|
|
|
# Router on which the chat-session endpoints in this module are registered.
router = APIRouter()
|
|
|
|
|
|
@router.post("/sessions", response_model=SessionOut)
def create_session(
    payload: SessionCreate,
    db: Session = Depends(get_db),
) -> SessionOut:
    """Create a chat session titled with ``payload.title``.

    The work is delegated to ``ChatService.create_session``; whatever it
    returns is handed back and rendered through the ``SessionOut`` response
    model declared on the route.
    """
    return ChatService(db).create_session(title=payload.title)
|
|
|
|
|
|
@router.get("/sessions", response_model=list[SessionOut])
def list_sessions(
    db: Session = Depends(get_db),
) -> list[SessionOut]:
    """Return the sessions reported by ``ChatService.list_sessions``."""
    return ChatService(db).list_sessions()
|
|
|
|
|
|
@router.get("/sessions/{session_id}", response_model=SessionDetailOut)
def get_session(
    session_id: int,
    db: Session = Depends(get_db),
) -> SessionDetailOut:
    """Return the session identified by ``session_id``.

    Raises:
        HTTPException: 404 when ``ChatService.get_session`` returns a falsy
            value for the given id.
    """
    found = ChatService(db).get_session(session_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
def delete_session(
    session_id: int,
    db: Session = Depends(get_db),
) -> dict[str, bool]:
    """Delete the session identified by ``session_id``.

    Returns:
        ``{"ok": True}`` when ``ChatService.delete_session`` reports success.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    deleted = ChatService(db).delete_session(session_id)
    if deleted:
        return {"ok": True}
    raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
@router.post("/sessions/{session_id}/messages")
async def send_message(
    session_id: int,
    payload: MessageCreate,
    db: Session = Depends(get_db),
) -> StreamingResponse:
    """Store the user's message and stream the reply as ``text/event-stream``.

    Raises:
        HTTPException: 404 when ``ChatService.get_session`` returns a falsy
            value for ``session_id``.
    """
    chat = ChatService(db)

    # NOTE(review): get_session / save_user_message are called without await
    # inside an async handler; if they do blocking DB I/O they will stall the
    # event loop while they run - confirm against ChatService.
    if not chat.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")

    # Persist the user message before streaming starts: otherwise it would
    # never reach the database if the SSE connection drops mid-stream.
    chat.save_user_message(session_id, payload.content)

    async def relay_chunks():
        # The message was stored above, so the service is told not to expect
        # to save it again (presumably what user_message_saved controls).
        reply = chat.stream_response(
            session_id,
            payload.content,
            user_message_saved=True,
        )
        async for piece in reply:
            yield piece

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        relay_chunks(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|