52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
|
|
class SseChunk:
    """A single Server-Sent Event: its event type and its decoded payload.

    Instances are created by ``_parse_sse_part`` in this module, which fills
    ``data`` with the result of ``json.loads`` on the event's data lines.
    """

    # __slots__ keeps instances small; one object is created per streamed event.
    __slots__ = ("event", "data")

    def __init__(self, event: str, data: dict[str, Any]) -> None:
        # Event type from the ``event:`` field ("message" when none was given).
        self.event = event
        # Decoded JSON payload; annotated as a dict, presumably an object at
        # the top level for this API — json.loads can return other JSON types.
        self.data = data

    def __repr__(self) -> str:
        return f"{type(self).__name__}(event={self.event!r}, data={self.data!r})"
|
|
|
|
|
|
def _parse_sse_part(part: str) -> SseChunk | None:
|
|
if not part.strip():
|
|
return None
|
|
event = "message"
|
|
data = ""
|
|
for line in part.split("\n"):
|
|
if line.startswith("event: "):
|
|
event = line[7:]
|
|
elif line.startswith("data: "):
|
|
data = line[6:]
|
|
if not data:
|
|
return None
|
|
return SseChunk(event=event, data=json.loads(data))
|
|
|
|
|
|
async def iter_sse(response: httpx.Response) -> AsyncIterator[SseChunk]:
    """Yield parsed Server-Sent Events from a streaming HTTP response.

    The response body is read incrementally via ``response.aiter_text()``.
    Events are separated by a blank line, and each event block is handed to
    ``_parse_sse_part``. Blocks that parse to ``None`` (comments,
    keep-alives, data-less events) are skipped. CRLF, CR and LF line endings
    are all accepted. A trailing event that is not followed by a blank line
    when the stream ends is still parsed and yielded.

    Args:
        response: An httpx response, presumably opened in streaming mode by
            the caller (``aread``/``aiter_text`` are awaited here).

    Raises:
        RuntimeError: if ``response.status_code >= 400``. The message is the
            response body decoded as UTF-8 (invalid bytes replaced), or
            ``"HTTP <status>"`` when the body is empty.
    """
    if response.status_code >= 400:
        detail = (await response.aread()).decode("utf-8", errors="replace")
        raise RuntimeError(detail or f"HTTP {response.status_code}")

    buffer = ""
    async for chunk in response.aiter_text():
        buffer += chunk
        # Normalise line endings so the event separator is always "\n\n".
        # A trailing "\r" may be the first half of a "\r\n" split across
        # chunks; hold it back so it is not mistaken for a bare-CR line end.
        held = ""
        if buffer.endswith("\r"):
            buffer, held = buffer[:-1], "\r"
        buffer = buffer.replace("\r\n", "\n").replace("\r", "\n")
        # str.split always returns at least one element; the last one is the
        # (possibly empty) incomplete event that must wait for more input.
        *complete, buffer = buffer.split("\n\n")
        buffer += held
        for part in complete:
            parsed = _parse_sse_part(part)
            if parsed is not None:
                yield parsed

    # Flush an event left unterminated when the stream closed.
    if buffer.strip():
        parsed = _parse_sse_part(buffer)
        if parsed is not None:
            yield parsed
|