import asyncio
import json

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from app.api.chat_schemas import GenerationStatusOut, MessagesPageOut
from app.api.schemas import (
    MessageCreate,
    SessionCreate,
    SessionDetailOut,
    SessionOut,
)
from app.auth.deps import get_current_user
from app.chat.generation import (
    GenerationBusyError,
    get_active_handle,
    is_generation_active,
    start_generation,
    subscribe_generation,
)
from app.chat.service import ChatService
from app.config import get_settings
from app.db.base import get_db
from app.db.models import User
from app.vision import VisionService, format_user_messages, vision_debug_payloads
from app.vision.analyze import VisionUnavailableError
from app.vision.preprocess import prepare_image
from app.vision.storage import format_upload_images_markdown, save_upload

router = APIRouter()

ALLOWED_IMAGE_TYPES = {"image/jpeg", "image/png", "image/webp", "image/gif"}

# Response headers shared by both text/event-stream endpoints in this module.
_SSE_HEADERS = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "X-Accel-Buffering": "no",
}


def _require_session(service: ChatService, session_id: int):
    """Return the session from ``service.get_session`` or raise a 404."""
    session = service.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    return session


def _sse_response(stream) -> StreamingResponse:
    """Wrap an async chunk iterator in a ``text/event-stream`` response."""
    return StreamingResponse(
        stream,
        media_type="text/event-stream",
        headers=dict(_SSE_HEADERS),
    )


@router.post("/sessions", response_model=SessionOut)
def create_session(
    payload: SessionCreate,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> SessionOut:
    """Create a chat session with the given title for the current user."""
    service = ChatService(db, user.id)
    return service.create_session(title=payload.title)


@router.get("/sessions", response_model=list[SessionOut])
def list_sessions(
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> list[SessionOut]:
    """List the chat sessions returned by the current user's chat service."""
    service = ChatService(db, user.id)
    return service.list_sessions()


@router.get("/sessions/{session_id}", response_model=SessionDetailOut)
def get_session(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> SessionDetailOut:
    """Return one session in detail, or 404 when it cannot be found."""
    service = ChatService(db, user.id)
    return _require_session(service, session_id)


@router.get("/sessions/{session_id}/messages", response_model=MessagesPageOut)
def list_messages(
    session_id: int,
    limit: int = 30,
    before_id: int | None = None,
    after_id: int | None = None,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> MessagesPageOut:
    """Return one page of messages for a session; ``limit`` is clamped to 1..100."""
    service = ChatService(db, user.id)
    _require_session(service, session_id)
    messages, has_more = service.list_messages(
        session_id,
        limit=min(max(limit, 1), 100),
        before_id=before_id,
        after_id=after_id,
    )
    return MessagesPageOut(messages=messages, has_more=has_more)


@router.get("/sessions/{session_id}/generation", response_model=GenerationStatusOut)
def generation_status(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> GenerationStatusOut:
    """Report whether a generation is currently active for the session."""
    service = ChatService(db, user.id)
    _require_session(service, session_id)
    return GenerationStatusOut(active=is_generation_active(session_id))


@router.get("/sessions/{session_id}/generation/stream")
async def generation_stream(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Attach to the active generation of a session and stream its chunks.

    Responds with 404 when the session is missing or no generation is active.
    """
    service = ChatService(db, user.id)
    _require_session(service, session_id)
    handle = get_active_handle(session_id)
    if not handle:
        raise HTTPException(status_code=404, detail="No active generation")

    async def event_stream():
        async for chunk in subscribe_generation(handle):
            yield chunk

    return _sse_response(event_stream())


@router.delete("/sessions/{session_id}")
def delete_session(
    session_id: int,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict[str, bool]:
    """Delete a session; responds with 404 when the service reports no deletion."""
    service = ChatService(db, user.id)
    if not service.delete_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    return {"ok": True}


def _collect_form_uploads(form) -> list:
    """Gather file-like values from the ``images`` and ``image`` form fields.

    Values without a ``read`` attribute (for example plain text fields) are
    skipped, and the same object is never returned twice.
    """
    uploads: list = []
    seen_ids: set[int] = set()

    def _append(item) -> None:
        if item is None or not hasattr(item, "read"):
            return
        item_id = id(item)
        if item_id in seen_ids:
            return
        seen_ids.add(item_id)
        uploads.append(item)

    if hasattr(form, "getlist"):
        for item in form.getlist("images"):
            _append(item)
    _append(form.get("image"))
    return uploads


async def _analyze_upload(raw: bytes, *, caption: str, user_id: int):
    """Prepare, store and analyze one image.

    Returns a ``(analysis_result, stored_filename)`` pair, where the filename
    is whatever ``save_upload`` returned for the prepared image.
    """
    prepared = prepare_image(raw)
    filename = save_upload(prepared, user_id=user_id)
    result = await VisionService().analyze_prepared(prepared, user_hint=caption)
    return result, filename


async def _read_json_message(request: Request) -> str:
    """Parse a JSON request body into the message text.

    Raises:
        HTTPException: 400 when the body is not decodable JSON, 422 when it is
            JSON but does not validate as ``MessageCreate``.
    """
    try:
        body = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # Undecodable bytes raise UnicodeDecodeError rather than JSONDecodeError;
        # both are client errors, not server faults.
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    try:
        payload = MessageCreate.model_validate(body)
    except ValueError as exc:
        # pydantic's ValidationError subclasses ValueError. Left unhandled it
        # would surface as a 500 although the client sent the bad payload.
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    return payload.content


async def _read_multipart_message(request: Request, *, user_id: int) -> tuple[str, dict | None]:
    """Build the message text and vision debug payload from a multipart upload.

    Raises:
        HTTPException: 400 for missing, too many, empty or unsupported images
            and when no message text could be built; 502 when the vision
            backend raises ``VisionUnavailableError``.
    """
    form = await request.form()
    caption = str(form.get("content") or "").strip()
    uploads = _collect_form_uploads(form)
    if not uploads:
        raise HTTPException(
            status_code=400,
            detail="Field 'images' or 'image' is required for multipart upload",
        )

    max_images = max(1, int(get_settings().vision_max_images))
    if len(uploads) > max_images:
        raise HTTPException(
            status_code=400,
            detail=f"Too many images (max {max_images})",
        )

    raw_images: list[bytes] = []
    for upload in uploads:
        raw = await upload.read()
        if not raw:
            raise HTTPException(status_code=400, detail="Empty image file")
        declared = getattr(upload, "content_type", None) or "application/octet-stream"
        # Media types are case-insensitive and may carry parameters, so compare
        # only the normalized "type/subtype" part against the allow-list.
        mime = declared.split(";", 1)[0].strip().lower()
        if mime not in ALLOWED_IMAGE_TYPES:
            raise HTTPException(status_code=400, detail=f"Unsupported image type: {declared}")
        raw_images.append(raw)

    try:
        analyzed = await asyncio.gather(
            *(_analyze_upload(raw, caption=caption, user_id=user_id) for raw in raw_images)
        )
    except VisionUnavailableError as exc:
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    results = [item[0] for item in analyzed]
    filenames = [item[1] for item in analyzed]
    debug = vision_debug_payloads(results)
    vision_text = format_user_messages(caption, results)
    images_md = format_upload_images_markdown(user_id, filenames)
    user_text = f"{images_md}\n\n{vision_text}" if images_md else vision_text
    if not user_text.strip():
        raise HTTPException(status_code=400, detail="Could not build message from image")
    return user_text, debug


async def _parse_message_request(
    request: Request,
    *,
    user_id: int,
) -> tuple[str, dict | None]:
    """Extract the user message from a JSON or multipart request.

    Returns:
        ``(user_text, vision_debug)``. ``vision_debug`` is ``None`` for JSON
        bodies and the result of ``vision_debug_payloads`` for image uploads.

    Raises:
        HTTPException: 400, 422 or 502 depending on what was wrong with the
            request; see the two helpers this function dispatches to.
    """
    content_type = (request.headers.get("content-type") or "").lower()
    if "multipart/form-data" not in content_type:
        return await _read_json_message(request), None
    return await _read_multipart_message(request, user_id=user_id)


@router.post("/sessions/{session_id}/messages")
async def send_message(
    session_id: int,
    request: Request,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Store a user message, start a generation and stream the reply.

    Responds with 404 for an unknown session and 409 when a generation is
    already running for it.
    """
    service = ChatService(db, user.id)
    _require_session(service, session_id)
    if is_generation_active(session_id):
        raise HTTPException(status_code=409, detail="Generation already in progress")

    user_text, vision_debug = await _parse_message_request(request, user_id=user.id)
    # NOTE(review): the message is persisted before start_generation can still
    # refuse with GenerationBusyError, so a concurrent request may leave a
    # stored user message without a reply - confirm whether that is acceptable.
    service.save_user_message(session_id, user_text)
    try:
        handle = await start_generation(session_id, user.id, user_text)
    except GenerationBusyError:
        raise HTTPException(status_code=409, detail="Generation already in progress") from None

    async def event_stream():
        # Emit the vision debug payload first, then the generation chunks.
        if vision_debug:
            yield ChatService._sse("vision", vision_debug)
        async for chunk in subscribe_generation(handle):
            yield chunk

    return _sse_response(event_stream())


@router.get("/sessions/{session_id}/context-preview")
def context_preview(
    session_id: int,
    query: str | None = None,
    db: Session = Depends(get_db),
    user: User = Depends(get_current_user),
) -> dict:
    """Return the chat service's context preview for a session and optional query."""
    # NOTE(review): unlike the sibling session routes there is no 404 guard
    # here; presumably ChatService.context_preview copes with an unknown
    # session - confirm.
    service = ChatService(db, user.id)
    return service.context_preview(session_id, query=query)