|
| 1 | +import os |
| 2 | +from pathlib import Path |
| 3 | +from unittest.mock import patch |
| 4 | + |
1 | 5 | import pytest |
2 | 6 | from fastapi.testclient import TestClient |
3 | | -from sqlalchemy import create_engine |
4 | | -from sqlalchemy.orm import sessionmaker |
5 | | -from sqlalchemy.pool import StaticPool |
| 7 | +from sqlalchemy.ext.asyncio import async_sessionmaker |
6 | 8 |
|
7 | 9 | from fastapi_app import create_app |
8 | | -from fastapi_app.postgres_models import Base |
9 | | - |
10 | | -POSTGRESQL_DATABASE_URL = "postgresql://admin:postgres@localhost:5432/postgres" |
11 | | - |
12 | | - |
13 | | -# Create a SQLAlchemy engine |
14 | | -engine = create_engine( |
15 | | - POSTGRESQL_DATABASE_URL, |
16 | | - poolclass=StaticPool, |
| 10 | +from fastapi_app.globals import global_storage |
| 11 | + |
| 12 | +POSTGRES_HOST = "localhost" |
| 13 | +POSTGRES_USERNAME = "admin" |
| 14 | +POSTGRES_DATABASE = "postgres" |
| 15 | +POSTGRES_PASSWORD = "postgres" |
| 16 | +POSTGRES_SSL = "prefer" |
| 17 | +POSTGRESQL_DATABASE_URL = ( |
| 18 | + f"postgresql+asyncpg://{POSTGRES_USERNAME}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}/{POSTGRES_DATABASE}" |
17 | 19 | ) |
18 | 20 |
|
19 | | -# Create a sessionmaker to manage sessions |
20 | | -TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) |
| 21 | + |
| 22 | +@pytest.fixture(scope="session") |
| 23 | +def setup_env(): |
| 24 | + os.environ["POSTGRES_HOST"] = POSTGRES_HOST |
| 25 | + os.environ["POSTGRES_USERNAME"] = POSTGRES_USERNAME |
| 26 | + os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE |
| 27 | + os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD |
| 28 | + os.environ["POSTGRES_SSL"] = POSTGRES_SSL |
| 29 | + os.environ["POSTGRESQL_DATABASE_URL"] = POSTGRESQL_DATABASE_URL |
| 30 | + os.environ["RUNNING_IN_PRODUCTION"] = "False" |
| 31 | + os.environ["OPENAI_API_KEY"] = "fakekey" |
21 | 32 |
|
22 | 33 |
|
23 | 34 | @pytest.fixture(scope="session") |
24 | | -def setup_database(): |
25 | | - """Create tables in the database for all tests.""" |
26 | | - try: |
27 | | - Base.metadata.create_all(bind=engine) |
| 35 | +def mock_azure_credential(): |
| 36 | + """Mock the Azure credential for testing.""" |
| 37 | + with patch("azure.identity.DefaultAzureCredential", return_value=None): |
28 | 38 | yield |
29 | | - Base.metadata.drop_all(bind=engine) |
30 | | - except Exception as e: |
31 | | - pytest.skip(f"Unable to connect to the database: {e}") |
32 | 39 |
|
33 | 40 |
|
34 | 41 | @pytest.fixture(scope="session") |
35 | | -def app(): |
| 42 | +def app(setup_env, mock_azure_credential): |
36 | 43 | """Create a FastAPI app.""" |
37 | | - return create_app() |
38 | | - |
39 | | - |
40 | | -@pytest.fixture(scope="function") |
41 | | -def db_session(setup_database): |
42 | | - """Create a new database session with a rollback at the end of the test.""" |
43 | | - connection = engine.connect() |
44 | | - transaction = connection.begin() |
45 | | - session = TestingSessionLocal(bind=connection) |
46 | | - yield session |
47 | | - session.close() |
48 | | - transaction.rollback() |
49 | | - connection.close() |
| 44 | + if not Path("src/static/").exists(): |
| 45 | + pytest.skip("Please generate frontend files first!") |
| 46 | + return create_app(is_testing=True) |
50 | 47 |
|
51 | 48 |
|
52 | 49 | @pytest.fixture(scope="function") |
53 | | -def test_db_client(app, db_session): |
54 | | - """Create a test client that uses the override_get_db fixture to return a session.""" |
55 | | - |
56 | | - def override_db_session(): |
57 | | - try: |
58 | | - yield db_session |
59 | | - finally: |
60 | | - db_session.close() |
| 50 | +def test_client(app): |
| 51 | + """Create a test client.""" |
61 | 52 |
|
62 | | - app.router.lifespan = override_db_session |
63 | 53 | with TestClient(app) as test_client: |
64 | 54 | yield test_client |
65 | 55 |
|
66 | 56 |
|
67 | | -@pytest.fixture(scope="session") |
68 | | -def test_client(app): |
69 | | - """Create a test client.""" |
70 | | - with TestClient(app) as test_client: |
71 | | - yield test_client |
| 57 | +@pytest.fixture(scope="function") |
| 58 | +def db_session(): |
| 59 | + """Create a new database session with a rollback at the end of the test.""" |
| 60 | + async_sesion = async_sessionmaker(autocommit=False, autoflush=False, bind=global_storage.engine) |
| 61 | + session = async_sesion() |
| 62 | + session.begin() |
| 63 | + yield session |
| 64 | + session.rollback() |
| 65 | + session.close() |
0 commit comments