253 lines
9.0 KiB
Python
253 lines
9.0 KiB
Python
import asyncio
|
|
import json
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.api.chat_schemas import GenerationStatusOut, MessagesPageOut
|
|
from app.api.schemas import (
|
|
MessageCreate,
|
|
SessionCreate,
|
|
SessionDetailOut,
|
|
SessionOut,
|
|
)
|
|
from app.auth.deps import get_current_user
|
|
from app.chat.generation import (
|
|
GenerationBusyError,
|
|
get_active_handle,
|
|
is_generation_active,
|
|
start_generation,
|
|
subscribe_generation,
|
|
)
|
|
from app.chat.service import ChatService
|
|
from app.config import get_settings
|
|
from app.db.base import get_db
|
|
from app.db.models import User
|
|
from app.vision import VisionService, format_user_messages, vision_debug_payloads
|
|
from app.vision.analyze import VisionUnavailableError
|
|
from app.vision.preprocess import prepare_image
|
|
from app.vision.storage import format_upload_images_markdown, save_upload
|
|
|
|
router = APIRouter()

# MIME types accepted for multipart image uploads; each upload's declared
# content type is checked against this set when parsing message requests.
ALLOWED_IMAGE_TYPES = {
    "image/gif",
    "image/jpeg",
    "image/png",
    "image/webp",
}
|
|
|
|
|
|
@router.post("/sessions", response_model=SessionOut)
def create_session(
    payload: SessionCreate,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> SessionOut:
    """Create a chat session for the current user using the payload's title.

    The value returned by ``ChatService.create_session`` is serialized through
    the ``SessionOut`` response model.
    """
    return ChatService(db, user.id).create_session(title=payload.title)
|
|
|
|
|
|
@router.get("/sessions", response_model=list[SessionOut])
def list_sessions(
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> list[SessionOut]:
    """Return the sessions that ``ChatService`` lists for the current user id."""
    chat = ChatService(db, user.id)
    sessions = chat.list_sessions()
    return sessions
|
|
|
|
|
|
@router.get("/sessions/{session_id}", response_model=SessionDetailOut)
def get_session(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> SessionDetailOut:
    """Fetch one session by id.

    ``ChatService`` is constructed with the caller's user id; a falsy result
    from its ``get_session`` is reported as a 404.

    Raises:
        HTTPException: 404 when the session cannot be loaded.
    """
    found = ChatService(db, user.id).get_session(session_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
@router.get("/sessions/{session_id}/messages", response_model=MessagesPageOut)
def list_messages(
    session_id: int,
    limit: int = 30,
    before_id: int | None = None,
    after_id: int | None = None,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> MessagesPageOut:
    """Return one page of messages from a session.

    ``limit`` is clamped into the range 1..100 before querying. ``before_id``
    and ``after_id`` are forwarded to ``ChatService.list_messages`` unchanged.

    Raises:
        HTTPException: 404 when the session cannot be loaded.
    """
    chat = ChatService(db, user.id)
    if not chat.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")

    # Clamp the page size: never below 1, never above 100.
    page_size = max(1, min(limit, 100))
    page_messages, more_available = chat.list_messages(
        session_id,
        limit=page_size,
        before_id=before_id,
        after_id=after_id,
    )
    return MessagesPageOut(messages=page_messages, has_more=more_available)
|
|
|
|
|
|
@router.get("/sessions/{session_id}/generation", response_model=GenerationStatusOut)
def generation_status(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> GenerationStatusOut:
    """Report whether a generation is currently running for the session.

    Raises:
        HTTPException: 404 when the session cannot be loaded.
    """
    found = ChatService(db, user.id).get_session(session_id)
    if not found:
        raise HTTPException(status_code=404, detail="Session not found")
    return GenerationStatusOut(active=is_generation_active(session_id))
|
|
|
|
|
|
@router.get("/sessions/{session_id}/generation/stream")
async def generation_stream(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Attach to the session's in-progress generation and stream it as SSE.

    Raises:
        HTTPException: 404 when the session cannot be loaded, or when there is
            no active generation handle for it.
    """
    if not ChatService(db, user.id).get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")

    handle = get_active_handle(session_id)
    if not handle:
        raise HTTPException(status_code=404, detail="No active generation")

    async def _relay():
        # subscribe_generation is only invoked once the response starts
        # iterating; every chunk is forwarded unchanged.
        async for piece in subscribe_generation(handle):
            yield piece

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        # Tells nginx not to buffer the event stream.
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(_relay(), media_type="text/event-stream", headers=sse_headers)
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
def delete_session(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, bool]:
    """Delete a session and acknowledge with ``{"ok": True}``.

    Raises:
        HTTPException: 404 when ``ChatService.delete_session`` returns a falsy
            value.
    """
    deleted = ChatService(db, user.id).delete_session(session_id)
    if deleted:
        return {"ok": True}
    raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
def _collect_form_uploads(form) -> list:
    """Gather the file-like upload objects from a parsed multipart form.

    Entries of the repeated ``images`` field come first (only when the form
    provides ``getlist``), followed by the single ``image`` field. Values that
    are ``None`` or have no ``read`` attribute (e.g. plain text fields) are
    skipped, and the same object is never returned twice (identity-based).
    """
    candidates = list(form.getlist("images")) if hasattr(form, "getlist") else []
    candidates.append(form.get("image"))

    collected: list = []
    seen: set[int] = set()
    for candidate in candidates:
        is_file_like = candidate is not None and hasattr(candidate, "read")
        if not is_file_like or id(candidate) in seen:
            continue
        seen.add(id(candidate))
        collected.append(candidate)
    return collected
|
|
|
|
|
|
async def _analyze_upload(raw: bytes, *, caption: str, user_id: int):
    """Preprocess, store, and run vision analysis on one uploaded image.

    Args:
        raw: Raw bytes of the uploaded file.
        caption: User-supplied text, passed to the vision service as ``user_hint``.
        user_id: Owner of the upload, passed to ``save_upload``.

    Returns:
        ``(result, filename)``: the ``analyze_prepared`` result and the name
        returned by ``save_upload``.

    Raises:
        Whatever ``prepare_image``, ``save_upload`` or
        ``VisionService.analyze_prepared`` raise; the caller maps
        ``VisionUnavailableError`` to an HTTP error.
    """
    # prepare_image and save_upload are synchronous (presumably image
    # processing and file I/O). Calling them directly would block the event
    # loop and serialize the caller's asyncio.gather, so run them in worker
    # threads instead.
    prepared = await asyncio.to_thread(prepare_image, raw)
    filename = await asyncio.to_thread(save_upload, prepared, user_id=user_id)
    result = await VisionService().analyze_prepared(prepared, user_hint=caption)
    return result, filename
|
|
|
|
|
|
async def _parse_json_message(request: Request) -> tuple[str, None]:
    """Parse and validate a JSON send-message body; return ``(content, None)``.

    Raises:
        HTTPException: 400 if the body cannot be decoded as JSON, 422 if it
            does not validate as ``MessageCreate``.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and the UnicodeDecodeError raised for
        # undecodable bytes both subclass ValueError. Catching only the former
        # let the latter escape as a 500.
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    try:
        payload = MessageCreate.model_validate(body)
    except ValueError as exc:
        # pydantic's ValidationError subclasses ValueError. Validation happens
        # inside the handler, not during FastAPI's parameter parsing, so
        # without this a bad payload surfaced as a 500 rather than a 422.
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    return payload.content, None


async def _read_image_uploads(uploads: list) -> list[bytes]:
    """Validate multipart uploads and return their raw bytes in order.

    Raises:
        HTTPException: 400 if there are no uploads, more than the configured
            ``vision_max_images`` (treated as at least 1), an empty file, or a
            declared content type outside ``ALLOWED_IMAGE_TYPES``.
    """
    if not uploads:
        raise HTTPException(status_code=400, detail="Field 'images' or 'image' is required for multipart upload")

    max_images = max(1, int(get_settings().vision_max_images))
    if len(uploads) > max_images:
        raise HTTPException(
            status_code=400,
            detail=f"Too many images (max {max_images})",
        )

    raw_images: list[bytes] = []
    for upload in uploads:
        raw = await upload.read()
        if not raw:
            raise HTTPException(status_code=400, detail="Empty image file")
        # Only the client-declared content type is checked here.
        mime = getattr(upload, "content_type", None) or "application/octet-stream"
        if mime not in ALLOWED_IMAGE_TYPES:
            raise HTTPException(status_code=400, detail=f"Unsupported image type: {mime}")
        raw_images.append(raw)
    return raw_images


async def _parse_message_request(
    request: Request,
    *,
    user_id: int,
) -> tuple[str, dict | None]:
    """Turn a send-message request into the text to store, plus vision debug data.

    Two request shapes are accepted:

    * Any non-multipart content type is parsed as JSON and validated with
      ``MessageCreate``. Its ``content`` is returned with no debug payload.
    * ``multipart/form-data`` carries an optional ``content`` caption plus one
      or more images under ``images`` / ``image``. Every image is analyzed
      concurrently. The message text is the stored-image markdown (when any)
      followed by the formatted vision output.

    Args:
        request: The incoming request.
        user_id: Owner of any uploaded images.

    Returns:
        ``(user_text, vision_debug)``. ``vision_debug`` is ``None`` for JSON
        requests, otherwise the output of ``vision_debug_payloads``.

    Raises:
        HTTPException: 400 for malformed JSON, invalid uploads, or when no
            usable text can be built; 422 when the JSON body fails validation;
            502 when the vision backend is unavailable.
    """
    content_type = (request.headers.get("content-type") or "").lower()
    if "multipart/form-data" not in content_type:
        return await _parse_json_message(request)

    form = await request.form()
    caption = str(form.get("content") or "").strip()
    raw_images = await _read_image_uploads(_collect_form_uploads(form))

    try:
        analyzed = await asyncio.gather(
            *(_analyze_upload(raw, caption=caption, user_id=user_id) for raw in raw_images)
        )
    except VisionUnavailableError as exc:
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    results = [result for result, _ in analyzed]
    filenames = [filename for _, filename in analyzed]

    debug = vision_debug_payloads(results)
    vision_text = format_user_messages(caption, results)
    images_md = format_upload_images_markdown(user_id, filenames)
    user_text = f"{images_md}\n\n{vision_text}" if images_md else vision_text
    if not user_text.strip():
        raise HTTPException(status_code=400, detail="Could not build message from image")
    return user_text, debug
|
|
|
|
|
|
@router.post("/sessions/{session_id}/messages")
async def send_message(
    session_id: int,
    request: Request,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Store a user message and stream the assistant's reply as server-sent events.

    The body may be JSON or multipart with images; it is parsed by
    ``_parse_message_request``. When vision debug data is present
    (multipart uploads), a ``vision`` event is emitted before the generation
    chunks.

    Raises:
        HTTPException: 404 if the session cannot be loaded; 409 if a generation
            is already running for the session; otherwise whatever
            ``_parse_message_request`` raises for an invalid body.
    """
    service = ChatService(db, user.id)
    if not service.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")

    # Cheap pre-check so a busy session doesn't trigger (slow) image analysis.
    if is_generation_active(session_id):
        raise HTTPException(status_code=409, detail="Generation already in progress")

    user_text, vision_debug = await _parse_message_request(request, user_id=user.id)

    # Parsing may have awaited vision analysis for a long time. Re-check before
    # persisting: a generation that started meanwhile would otherwise leave
    # this user message saved but unanswered behind a 409 from
    # start_generation. This narrows, but does not fully close, that window.
    if is_generation_active(session_id):
        raise HTTPException(status_code=409, detail="Generation already in progress")

    service.save_user_message(session_id, user_text)

    try:
        handle = await start_generation(session_id, user.id, user_text)
    except GenerationBusyError:
        raise HTTPException(status_code=409, detail="Generation already in progress") from None

    async def event_stream():
        if vision_debug:
            yield ChatService._sse("vision", vision_debug)
        async for chunk in subscribe_generation(handle):
            yield chunk

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
@router.get("/sessions/{session_id}/context-preview")
def context_preview(
    session_id: int,
    query: str | None = None,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict:
    """Return ``ChatService.context_preview`` for a session, optionally for ``query``.

    Raises:
        HTTPException: 404 when the session cannot be loaded. This matches the
            other session-scoped endpoints, which all verify the session before
            acting on it.
    """
    service = ChatService(db, user.id)
    if not service.get_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    return service.context_preview(session_id, query=query)
|