Files
2026-06-13 20:20:56 +00:00

155 lines
5.0 KiB
Python

import os
import tempfile
import uuid
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.auth.tokens import hash_token
from app.db.base import Base, get_db
from app.db.models import CharacterCard, ChatSession, MemoryFact, ShoppingList, User
def _seed_database(session_factory, token_a: str, token_b: str) -> None:
    """Insert the two users and the per-user rows the isolation tests rely on.

    Creates ``alice`` (authenticated by *token_a*) and ``bob`` (authenticated by
    *token_b*). Each gets one chat session, one shopping list named
    ``"groceries"`` and one character card. Only Alice gets a memory fact.
    The session is always closed, even if an insert or commit fails.
    """
    db = session_factory()
    try:
        user_a = User(
            username="alice",
            display_name="Alice",
            api_token_hash=hash_token(token_a),
            is_active=True,
        )
        user_b = User(
            username="bob",
            display_name="Bob",
            api_token_hash=hash_token(token_b),
            is_active=True,
        )
        db.add_all([user_a, user_b])
        db.commit()
        # Reload so the generated primary keys are available for the FKs below.
        db.refresh(user_a)
        db.refresh(user_b)

        db.add(ChatSession(user_id=user_a.id, title="Alice chat"))
        db.add(ChatSession(user_id=user_b.id, title="Bob chat"))
        # Same list name for both users: isolation must be per-user, not per-name.
        db.add(ShoppingList(user_id=user_a.id, name="groceries"))
        db.add(ShoppingList(user_id=user_b.id, name="groceries"))
        db.add(
            CharacterCard(
                user_id=user_a.id,
                card_json='{"spec":"chara_card_v2","spec_version":"2.0","data":{"name":"A","rp_persona_id":"persona-a"}}',
            )
        )
        db.add(
            CharacterCard(
                user_id=user_b.id,
                card_json='{"spec":"chara_card_v2","spec_version":"2.0","data":{"name":"B","rp_persona_id":"persona-b"}}',
            )
        )
        # Memory fact owned only by Alice (content text: "secret only for owner").
        db.add(
            MemoryFact(
                user_id=user_a.id,
                category="person",
                content="Секрет только для owner",
                source="test",
            )
        )
        db.commit()
    finally:
        db.close()


@pytest.fixture()
def client(monkeypatch):
    """Yield a ``TestClient`` backed by a fresh, seeded, per-test SQLite database.

    The database is seeded by ``_seed_database``. The raw bearer tokens of the
    two users are exposed as ``test_client.tokens["a"]`` (alice) and
    ``test_client.tokens["b"]`` (bob). The app's ``get_db`` dependency is
    overridden so requests use this database.

    Environment variables are set through ``monkeypatch``, so they are restored
    after the test instead of leaking into later tests. Once the engine exists,
    all cleanup runs even if seeding or ``TestClient`` startup fails part-way:
    override removal, settings-cache reset, engine disposal and temp-file removal.
    """
    db_path = Path(tempfile.gettempdir()) / f"test_multi_{uuid.uuid4().hex}.db"
    db_url = f"sqlite:///{db_path.as_posix()}"
    monkeypatch.setenv("DATABASE_URL", db_url)
    monkeypatch.setenv("DEFAULT_API_TOKEN", "unused-in-tests")
    monkeypatch.setenv("AUTH_REQUIRED", "true")
    monkeypatch.setenv("RAG_ENABLED", "false")

    # Imported only after the environment is configured so app modules see it.
    from app.config import get_settings

    # get_settings is cached; drop any instance built from an older environment.
    get_settings.cache_clear()
    from app.main import create_app

    # StaticPool hands out a single shared connection, so every session sees
    # the same database.
    engine = create_engine(
        db_url,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    token_a = "token-user-a"
    token_b = "token-user-b"
    try:
        Base.metadata.create_all(bind=engine)
        _seed_database(TestingSessionLocal, token_a, token_b)
        app = create_app()

        def override_get_db():
            session = TestingSessionLocal()
            try:
                yield session
            finally:
                session.close()

        app.dependency_overrides[get_db] = override_get_db
        try:
            with TestClient(app) as test_client:
                test_client.tokens = {"a": token_a, "b": token_b}
                yield test_client
        finally:
            app.dependency_overrides.clear()
    finally:
        get_settings.cache_clear()
        # Close the pooled connection so this engine no longer holds the file open.
        engine.dispose()
        try:
            db_path.unlink(missing_ok=True)
        except OSError:
            # Best effort: something else (e.g. the app's own engine) may still
            # hold the file open, which blocks deletion on some platforms.
            pass
def _headers(client: TestClient, who: str) -> dict[str, str]:
    """Return a bearer ``Authorization`` header for the seeded user *who*.

    *who* is a key of ``client.tokens`` (the ``client`` fixture provides
    ``"a"`` and ``"b"``). An unknown key raises ``KeyError``.
    """
    token = client.tokens[who]
    return {"Authorization": f"Bearer {token}"}
def test_chat_sessions_isolated(client: TestClient):
    """Each user's session listing contains only the chat seeded for that user."""
    alice_res = client.get("/api/v1/chat/sessions", headers=_headers(client, "a"))
    bob_res = client.get("/api/v1/chat/sessions", headers=_headers(client, "b"))
    assert alice_res.status_code == 200
    assert bob_res.status_code == 200
    # Compare as sets: only membership matters, not listing order.
    alice_titles, bob_titles = (
        {session["title"] for session in res.json()} for res in (alice_res, bob_res)
    )
    assert alice_titles == {"Alice chat"}
    assert bob_titles == {"Bob chat"}
def test_character_cards_isolated(client: TestClient):
    """Each user's character endpoint returns the card seeded for that user."""
    res_a = client.get("/api/v1/character", headers=_headers(client, "a"))
    res_b = client.get("/api/v1/character", headers=_headers(client, "b"))
    # Check the status first, as the sibling tests do. An auth or server error
    # then fails with a readable message instead of a KeyError/TypeError raised
    # while indexing the error body.
    assert res_a.status_code == 200, res_a.text
    assert res_b.status_code == 200, res_b.text
    assert res_a.json()["data"]["rp_persona_id"] == "persona-a"
    assert res_b.json()["data"]["rp_persona_id"] == "persona-b"
def test_shopping_same_name_different_users(client: TestClient):
    """Both users own a list named "groceries"; each one sees exactly one list."""
    responses = [
        client.get("/api/v1/shopping", headers=_headers(client, who))
        for who in ("a", "b")
    ]
    assert [res.status_code for res in responses] == [200, 200]
    assert [len(res.json()["lists"]) for res in responses] == [1, 1]
def test_missing_token_unauthorized(client: TestClient):
    """A request sent without an ``Authorization`` header is rejected with 401."""
    status = client.get("/api/v1/chat/sessions").status_code
    assert status == 401
def test_memory_facts_isolated(client: TestClient):
    """Alice's private memory fact is visible to her and hidden from Bob."""
    # Content of the fact seeded for Alice only ("secret only for owner").
    secret = "Секрет только для owner"
    res_a = client.get("/api/v1/memory", headers=_headers(client, "a"))
    res_b = client.get("/api/v1/memory", headers=_headers(client, "b"))
    assert res_a.status_code == 200
    assert res_b.status_code == 200

    def mentions_secret(payload) -> bool:
        # A missing or null "facts" field is treated as an empty list.
        facts = payload.get("facts") or []
        return any(secret in fact.get("content", "") for fact in facts)

    payload_a, payload_b = res_a.json(), res_b.json()
    assert mentions_secret(payload_a)
    assert not mentions_secret(payload_b)
    assert payload_b.get("total_facts", 0) == 0