Skip to content

Commit

Permalink
test: improve fixtures in unit tests for better reuse (#1638)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro authored Jul 4, 2023
1 parent 20c8ebf commit b763ede
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 85 deletions.
27 changes: 2 additions & 25 deletions tests/integration/test_studies_upgrade.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
import os
import time

import pytest
from antarest.core.tasks.model import TaskStatus
from starlette.testclient import TestClient

from antarest.core.tasks.model import TaskDTO, TaskStatus
from tests.integration.utils import wait_task_completion

RUN_ON_WINDOWS = os.name == "nt"


def wait_task_completion(
client: TestClient,
access_token: str,
task_id: str,
*,
timeout: float = 10,
) -> TaskDTO:
end_time = time.time() + timeout
while time.time() < end_time:
time.sleep(0.1)
res = client.get(
f"/v1/tasks/{task_id}",
headers={"Authorization": f"Bearer {access_token}"},
json={"wait_for_completion": True},
)
assert res.status_code == 200
task = TaskDTO(**res.json())
if task.status not in {TaskStatus.PENDING, TaskStatus.RUNNING}:
return task
raise TimeoutError(f"{timeout} seconds")


class TestStudyUpgrade:
@pytest.mark.skipif(
RUN_ON_WINDOWS, reason="This test runs randomly on Windows"
Expand Down
25 changes: 25 additions & 0 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import time
from typing import Callable

from antarest.core.tasks.model import TaskDTO, TaskStatus
from starlette.testclient import TestClient


def wait_for(
predicate: Callable[[], bool], timeout: float = 10, sleep_time: float = 1
Expand All @@ -13,3 +16,25 @@ def wait_for(
return
time.sleep(sleep_time)
raise TimeoutError(f"task is still in progress after {timeout} seconds")


def wait_task_completion(
client: TestClient,
access_token: str,
task_id: str,
*,
timeout: float = 10,
) -> TaskDTO:
end_time = time.time() + timeout
while time.time() < end_time:
time.sleep(0.1)
res = client.get(
f"/v1/tasks/{task_id}",
headers={"Authorization": f"Bearer {access_token}"},
json={"wait_for_completion": True},
)
assert res.status_code == 200
task = TaskDTO(**res.json())
if task.status not in {TaskStatus.PENDING, TaskStatus.RUNNING}:
return task
raise TimeoutError(f"{timeout} seconds")
31 changes: 25 additions & 6 deletions tests/storage/business/test_arealink_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,19 @@
from tests.storage.business.assets import ASSETS_DIR


@pytest.fixture
def empty_study(tmp_path: Path) -> FileStudy:
study_id = str(uuid.uuid4())
@pytest.fixture(name="empty_study")
def empty_study_fixture(tmp_path: Path) -> FileStudy:
"""
Fixture for preparing an empty study in the `tmp_path`
based on the "empty_study_810.zip" asset.
Args:
tmp_path: The temporary path provided by pytest.
Returns:
An instance of the `FileStudy` class representing the empty study.
"""
study_id = "5c22caca-b100-47e7-bbea-8b1b97aa26d9"
study_path = tmp_path.joinpath(study_id)
study_path.mkdir()
with ZipFile(ASSETS_DIR / "empty_study_810.zip") as zip_output:
Expand All @@ -66,9 +76,18 @@ def empty_study(tmp_path: Path) -> FileStudy:
return FileStudy(config, FileStudyTree(Mock(), config))


@pytest.fixture
def matrix_service(tmp_path: Path) -> SimpleMatrixService:
matrix_path = tmp_path.joinpath("matrix_store")
@pytest.fixture(name="matrix_service")
def matrix_service_fixture(tmp_path: Path) -> SimpleMatrixService:
"""
Fixture for creating a matrix service in the `tmp_path`.
Args:
tmp_path: The temporary path provided by pytest.
Returns:
An instance of the `SimpleMatrixService` class representing the matrix service.
"""
matrix_path = tmp_path.joinpath("matrix-store")
matrix_path.mkdir()
return SimpleMatrixService(matrix_path)

Expand Down
3 changes: 3 additions & 0 deletions tests/variantstudy/assets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pathlib import Path

ASSETS_DIR = Path(__file__).parent.resolve()
Binary file added tests/variantstudy/assets/empty_study_720.zip
Binary file not shown.
97 changes: 77 additions & 20 deletions tests/variantstudy/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import hashlib
import zipfile
from pathlib import Path
from unittest.mock import Mock

import numpy as np
import pytest
from sqlalchemy import create_engine

from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.dbmodel import Base
from antarest.matrixstore.service import MatrixService
Expand All @@ -28,29 +28,66 @@
from antarest.study.storage.variantstudy.model.command_context import (
CommandContext,
)
from sqlalchemy import create_engine
from tests.variantstudy.assets import ASSETS_DIR


@pytest.fixture
def matrix_service() -> MatrixService:
engine = create_engine("sqlite:///:memory:", echo=False)
@pytest.fixture(name="db_engine")
def db_engine_fixture():
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
yield engine
engine.dispose()


@pytest.fixture(name="db_middleware", autouse=True)
def db_middleware_fixture(db_engine):
# noinspection SpellCheckingInspection
DBSessionMiddleware(
yield DBSessionMiddleware(
None,
custom_engine=engine,
custom_engine=db_engine,
session_args={"autocommit": False, "autoflush": False},
)

matrix_service = Mock(spec=MatrixService)
matrix_service.create.side_effect = (
lambda data: data if isinstance(data, str) else "matrix_id"
)

@pytest.fixture(name="matrix_service")
def matrix_service_fixture() -> MatrixService:
"""
Fixture for creating a mocked matrix service.
Returns:
An instance of the `SimpleMatrixService` class representing the matrix service.
"""

def create(data):
"""
This function calculates a unique ID for each matrix, without storing
any data in the file system or the database.
"""
matrix = (
data
if isinstance(data, np.ndarray)
else np.array(data, dtype=np.float64)
)
matrix_hash = hashlib.sha256(matrix.data).hexdigest()
return matrix_hash

matrix_service = Mock(spec=MatrixService)
matrix_service.create.side_effect = create
return matrix_service


@pytest.fixture
def command_context(matrix_service: MatrixService) -> CommandContext:
@pytest.fixture(name="command_context")
def command_context_fixture(matrix_service: MatrixService) -> CommandContext:
"""
Fixture for creating a CommandContext object.
Args:
matrix_service: The MatrixService object.
Returns:
CommandContext: The CommandContext object.
"""
# sourcery skip: inline-immediately-returned-variable
command_context = CommandContext(
generator_matrix_constants=GeneratorMatrixConstants(
Expand All @@ -64,8 +101,17 @@ def command_context(matrix_service: MatrixService) -> CommandContext:
return command_context


@pytest.fixture
def command_factory(matrix_service: MatrixService) -> CommandFactory:
@pytest.fixture(name="command_factory")
def command_factory_fixture(matrix_service: MatrixService) -> CommandFactory:
"""
Fixture for creating a CommandFactory object.
Args:
matrix_service: The MatrixService object.
Returns:
CommandFactory: The CommandFactory object.
"""
return CommandFactory(
generator_matrix_constants=GeneratorMatrixConstants(
matrix_service=matrix_service
Expand All @@ -75,10 +121,21 @@ def command_factory(matrix_service: MatrixService) -> CommandFactory:
)


@pytest.fixture
def empty_study(tmp_path: Path, matrix_service: MatrixService) -> FileStudy:
project_dir: Path = Path(__file__).parent.parent.parent
empty_study_path: Path = project_dir / "resources" / "empty_study_720.zip"
@pytest.fixture(name="empty_study")
def empty_study_fixture(
tmp_path: Path, matrix_service: MatrixService
) -> FileStudy:
"""
Fixture for creating an empty FileStudy object.
Args:
tmp_path: The temporary path for extracting the empty study.
matrix_service: The MatrixService object.
Returns:
FileStudy: The empty FileStudy object.
"""
empty_study_path: Path = ASSETS_DIR / "empty_study_720.zip"
empty_study_destination_path = tmp_path.joinpath("empty-study")
with zipfile.ZipFile(empty_study_path, "r") as zip_empty_study:
zip_empty_study.extractall(empty_study_destination_path)
Expand All @@ -87,7 +144,7 @@ def empty_study(tmp_path: Path, matrix_service: MatrixService) -> FileStudy:
study_path=empty_study_destination_path,
path=empty_study_destination_path,
study_id="",
version=700,
version=720,
areas={},
sets={},
)
Expand Down
4 changes: 3 additions & 1 deletion tests/variantstudy/model/command/test_create_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def test_match(command_context: CommandContext):
assert not base.match(other_not_match)
assert not base.match(other_other)
assert base.match_signature() == "create_cluster%foo%foo"
assert base.get_inner_matrices() == ["matrix_id", "matrix_id"]
# check the matrices links
matrix_id = command_context.matrix_service.create([[0]])
assert base.get_inner_matrices() == [matrix_id, matrix_id]


def test_revert(command_context: CommandContext):
Expand Down
4 changes: 3 additions & 1 deletion tests/variantstudy/model/command/test_create_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def test_match(command_context: CommandContext):
assert not base.match(other_not_match)
assert not base.match(other_other)
assert base.match_signature() == "create_link%foo%bar"
assert base.get_inner_matrices() == ["matrix_id"]
# check the matrices links
matrix_id = command_context.matrix_service.create([[0]])
assert base.get_inner_matrices() == [matrix_id]


def test_revert(command_context: CommandContext):
Expand Down
Loading

0 comments on commit b763ede

Please sign in to comment.