added RAG, Multiuser, TG bot
This commit is contained in:
+158
-70
@@ -1,70 +1,158 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.schemas import MessageCreate, MessageOut, SessionCreate, SessionDetailOut, SessionOut
|
||||
from app.chat.service import ChatService
|
||||
from app.db.base import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=SessionOut)
|
||||
def create_session(payload: SessionCreate, db: Session = Depends(get_db)) -> SessionOut:
|
||||
service = ChatService(db)
|
||||
return service.create_session(title=payload.title)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=list[SessionOut])
|
||||
def list_sessions(db: Session = Depends(get_db)) -> list[SessionOut]:
|
||||
service = ChatService(db)
|
||||
return service.list_sessions()
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=SessionDetailOut)
|
||||
def get_session(session_id: int, db: Session = Depends(get_db)) -> SessionDetailOut:
|
||||
service = ChatService(db)
|
||||
session = service.get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
def delete_session(session_id: int, db: Session = Depends(get_db)) -> dict[str, bool]:
|
||||
service = ChatService(db)
|
||||
if not service.delete_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/messages")
|
||||
async def send_message(
|
||||
session_id: int,
|
||||
payload: MessageCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
service = ChatService(db)
|
||||
if not service.get_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Сохраняем user до стрима: иначе при обрыве SSE сообщение не попадает в БД.
|
||||
service.save_user_message(session_id, payload.content)
|
||||
|
||||
async def event_stream():
|
||||
async for chunk in service.stream_response(
|
||||
session_id,
|
||||
payload.content,
|
||||
user_message_saved=True,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.chat_schemas import GenerationStatusOut, MessagesPageOut
|
||||
from app.api.schemas import (
|
||||
MessageCreate,
|
||||
SessionCreate,
|
||||
SessionDetailOut,
|
||||
SessionOut,
|
||||
)
|
||||
from app.chat.generation import (
|
||||
GenerationBusyError,
|
||||
get_active_handle,
|
||||
is_generation_active,
|
||||
start_generation,
|
||||
subscribe_generation,
|
||||
)
|
||||
from app.chat.service import ChatService
|
||||
from app.auth.deps import get_current_user
|
||||
from app.db.base import get_db
|
||||
from app.db.models import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=SessionOut)
|
||||
def create_session(payload: SessionCreate, db: Session = Depends(get_db), user: User = Depends(get_current_user)) -> SessionOut:
|
||||
service = ChatService(db, user.id)
|
||||
return service.create_session(title=payload.title)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=list[SessionOut])
|
||||
def list_sessions(db: Session = Depends(get_db), user: User = Depends(get_current_user)) -> list[SessionOut]:
|
||||
service = ChatService(db, user.id)
|
||||
return service.list_sessions()
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=SessionDetailOut)
|
||||
def get_session(session_id: int, db: Session = Depends(get_db), user: User = Depends(get_current_user)) -> SessionDetailOut:
|
||||
service = ChatService(db, user.id)
|
||||
session = service.get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/messages", response_model=MessagesPageOut)
|
||||
def list_messages(
|
||||
session_id: int,
|
||||
limit: int = 30,
|
||||
before_id: int | None = None,
|
||||
after_id: int | None = None,
|
||||
db: Session = Depends(get_db), user: User = Depends(get_current_user),
|
||||
) -> MessagesPageOut:
|
||||
service = ChatService(db, user.id)
|
||||
if not service.get_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
messages, has_more = service.list_messages(
|
||||
session_id,
|
||||
limit=min(max(limit, 1), 100),
|
||||
before_id=before_id,
|
||||
after_id=after_id,
|
||||
)
|
||||
return MessagesPageOut(messages=messages, has_more=has_more)
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/generation", response_model=GenerationStatusOut)
|
||||
def generation_status(session_id: int, db: Session = Depends(get_db), user: User = Depends(get_current_user)) -> GenerationStatusOut:
|
||||
service = ChatService(db, user.id)
|
||||
if not service.get_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return GenerationStatusOut(active=is_generation_active(session_id))
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/generation/stream")
|
||||
async def generation_stream(session_id: int, db: Session = Depends(get_db), user: User = Depends(get_current_user)) -> StreamingResponse:
|
||||
service = ChatService(db, user.id)
|
||||
if not service.get_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
handle = get_active_handle(session_id)
|
||||
if not handle:
|
||||
raise HTTPException(status_code=404, detail="No active generation")
|
||||
|
||||
async def event_stream():
|
||||
async for chunk in subscribe_generation(handle):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
def delete_session(session_id: int, db: Session = Depends(get_db), user: User = Depends(get_current_user)) -> dict[str, bool]:
|
||||
service = ChatService(db, user.id)
|
||||
if not service.delete_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/messages")
|
||||
async def send_message(
|
||||
session_id: int,
|
||||
payload: MessageCreate,
|
||||
db: Session = Depends(get_db), user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
service = ChatService(db, user.id)
|
||||
if not service.get_session(session_id):
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
if is_generation_active(session_id):
|
||||
raise HTTPException(status_code=409, detail="Generation already in progress")
|
||||
|
||||
# Сохраняем user до стрима: иначе при обрыве SSE сообщение не попадает в БД.
|
||||
service.save_user_message(session_id, payload.content)
|
||||
|
||||
try:
|
||||
handle = await start_generation(session_id, user.id, payload.content)
|
||||
except GenerationBusyError:
|
||||
raise HTTPException(status_code=409, detail="Generation already in progress") from None
|
||||
|
||||
async def event_stream():
|
||||
try:
|
||||
async for chunk in subscribe_generation(handle):
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/context-preview")
|
||||
def context_preview(
|
||||
session_id: int,
|
||||
query: str | None = None,
|
||||
db: Session = Depends(get_db), user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
service = ChatService(db, user.id)
|
||||
return service.context_preview(session_id, query=query)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user