155 lines
5.0 KiB
Python
155 lines
5.0 KiB
Python
import os
|
|
import tempfile
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.auth.tokens import hash_token
|
|
from app.db.base import Base, get_db
|
|
from app.db.models import CharacterCard, ChatSession, MemoryFact, ShoppingList, User
|
|
|
|
|
|
@pytest.fixture()
def client():
    """Yield a ``TestClient`` backed by a throwaway SQLite file seeded for two users.

    Setup creates a uniquely named database file in the system temp directory,
    points the app at it through environment variables, creates the schema and
    inserts users "alice" and "bob".  Each user gets one chat session, one
    shopping list named "groceries" and one character card; a single memory
    fact is stored for alice only.  The app's ``get_db`` dependency is
    overridden so request handlers use sessions bound to the same engine.

    The yielded client carries a ``tokens`` attribute mapping ``"a"`` / ``"b"``
    to the plaintext bearer tokens of alice / bob.

    Teardown clears the dependency overrides, restores the environment
    variables touched here, clears the settings cache, disposes the engine and
    removes the database file (best effort).
    """
    db_path = Path(tempfile.gettempdir()) / f"test_multi_{uuid.uuid4().hex}.db"
    db_url = f"sqlite:///{db_path.as_posix()}"

    env_overrides = {
        "DATABASE_URL": db_url,
        "DEFAULT_API_TOKEN": "unused-in-tests",
        "AUTH_REQUIRED": "true",
        "RAG_ENABLED": "false",
    }
    # Remember the previous values so the fixture does not leak its
    # configuration into tests that run afterwards.
    saved_env = {key: os.environ.get(key) for key in env_overrides}
    os.environ.update(env_overrides)

    # Imported only after the environment is prepared, as in the original
    # ordering -- presumably so the app configuration picks up these values.
    from app.config import get_settings

    get_settings.cache_clear()

    from app.main import create_app

    engine = create_engine(
        db_url,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        Base.metadata.create_all(bind=engine)

        token_a = "token-user-a"
        token_b = "token-user-b"

        db = TestingSessionLocal()
        try:
            user_a = User(
                username="alice",
                display_name="Alice",
                api_token_hash=hash_token(token_a),
                is_active=True,
            )
            user_b = User(
                username="bob",
                display_name="Bob",
                api_token_hash=hash_token(token_b),
                is_active=True,
            )
            db.add_all([user_a, user_b])
            db.commit()
            # Reload both rows so their ids are available for the rows below.
            db.refresh(user_a)
            db.refresh(user_b)

            db.add(ChatSession(user_id=user_a.id, title="Alice chat"))
            db.add(ChatSession(user_id=user_b.id, title="Bob chat"))
            # Same list name for both users on purpose.
            db.add(ShoppingList(user_id=user_a.id, name="groceries"))
            db.add(ShoppingList(user_id=user_b.id, name="groceries"))
            db.add(
                CharacterCard(
                    user_id=user_a.id,
                    card_json='{"spec":"chara_card_v2","spec_version":"2.0","data":{"name":"A","rp_persona_id":"persona-a"}}',
                )
            )
            db.add(
                CharacterCard(
                    user_id=user_b.id,
                    card_json='{"spec":"chara_card_v2","spec_version":"2.0","data":{"name":"B","rp_persona_id":"persona-b"}}',
                )
            )
            # Only alice owns a memory fact.
            db.add(
                MemoryFact(
                    user_id=user_a.id,
                    category="person",
                    content="Секрет только для owner",
                    source="test",
                )
            )
            db.commit()
        finally:
            # Close the seeding session even if an insert or commit fails.
            db.close()

        app = create_app()

        def override_get_db():
            """Provide a session bound to the test engine and close it afterwards."""
            session = TestingSessionLocal()
            try:
                yield session
            finally:
                session.close()

        app.dependency_overrides[get_db] = override_get_db
        try:
            with TestClient(app) as test_client:
                test_client.tokens = {"a": token_a, "b": token_b}
                yield test_client
        finally:
            app.dependency_overrides.clear()
    finally:
        # Put the environment back exactly as it was before the fixture ran.
        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
        get_settings.cache_clear()
        # Release the pooled connection before deleting the file; an open
        # handle can prevent removal on some platforms.
        engine.dispose()
        try:
            db_path.unlink(missing_ok=True)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not fail the test.
            pass
|
|
|
|
|
|
def _headers(client: TestClient, who: str) -> dict[str, str]:
    """Return the bearer ``Authorization`` header for the token stored under *who*.

    *who* is a key of ``client.tokens``.
    """
    bearer_value = "Bearer {}".format(client.tokens[who])
    return {"Authorization": bearer_value}
|
|
|
|
|
|
def test_chat_sessions_isolated(client: TestClient):
    """Each user's chat session listing contains only that user's session title."""
    endpoint = "/api/v1/chat/sessions"
    alice_resp = client.get(endpoint, headers=_headers(client, "a"))
    bob_resp = client.get(endpoint, headers=_headers(client, "b"))

    for resp in (alice_resp, bob_resp):
        assert resp.status_code == 200

    # Collect the titles per user before comparing them with the expected sets.
    seen_titles = {
        who: {entry["title"] for entry in resp.json()}
        for who, resp in (("a", alice_resp), ("b", bob_resp))
    }
    assert seen_titles["a"] == {"Alice chat"}
    assert seen_titles["b"] == {"Bob chat"}
|
|
|
|
|
|
def test_character_cards_isolated(client: TestClient):
    """Each user's character endpoint returns the persona id from that user's own card."""
    res_a = client.get("/api/v1/character", headers=_headers(client, "a"))
    res_b = client.get("/api/v1/character", headers=_headers(client, "b"))
    # Check the status first, like the sibling tests, so an auth or server
    # error is reported as such instead of as a KeyError on the payload.
    assert res_a.status_code == 200
    assert res_b.status_code == 200
    assert res_a.json()["data"]["rp_persona_id"] == "persona-a"
    assert res_b.json()["data"]["rp_persona_id"] == "persona-b"
|
|
|
|
|
|
def test_shopping_same_name_different_users(client: TestClient):
    """Both users get a successful shopping response listing exactly one list each."""
    endpoint = "/api/v1/shopping"
    replies = [
        client.get(endpoint, headers=_headers(client, who))
        for who in ("a", "b")
    ]

    for reply in replies:
        assert reply.status_code == 200

    for reply in replies:
        assert len(reply.json()["lists"]) == 1
|
|
|
|
|
|
def test_missing_token_unauthorized(client: TestClient):
    """A chat session listing requested without auth headers yields HTTP 401."""
    response = client.get("/api/v1/chat/sessions")
    status = response.status_code
    assert status == 401
|
|
|
|
|
|
def test_memory_facts_isolated(client: TestClient):
    """The seeded memory fact shows up for user "a" and not for user "b"."""
    secret = "Секрет только для owner"
    endpoint = "/api/v1/memory"

    alice_resp = client.get(endpoint, headers=_headers(client, "a"))
    bob_resp = client.get(endpoint, headers=_headers(client, "b"))
    assert alice_resp.status_code == 200
    assert bob_resp.status_code == 200

    def mentions_secret(facts):
        """Tell whether any fact's content contains the secret text."""
        return any(secret in fact.get("content", "") for fact in facts)

    # A missing or empty "facts" key is treated as an empty list.
    alice_facts = alice_resp.json().get("facts") or []
    bob_facts = bob_resp.json().get("facts") or []

    assert mentions_secret(alice_facts)
    assert not mentions_secret(bob_facts)
    assert bob_resp.json().get("total_facts", 0) == 0
|