diff --git a/tests/integration/defs/conftest.py b/tests/integration/defs/conftest.py index 2490fdab24..ee5a75ee22 100644 --- a/tests/integration/defs/conftest.py +++ b/tests/integration/defs/conftest.py @@ -1817,12 +1817,18 @@ def skip_by_mpi_world_size(request): "fixture for skip less device count" if request.node.get_closest_marker('skip_less_mpi_world_size'): mpi_world_size = get_mpi_world_size() + device_count = get_device_count() + if mpi_world_size == 1: + # For mpi_world_size == 1 case, we only need to check device count since we can spawn mpi workers in the test itself + total_count = device_count + else: + # Otherwise, we follow the mpi world size setting + total_count = mpi_world_size expected_count = request.node.get_closest_marker( 'skip_less_mpi_world_size').args[0] - if expected_count > int(mpi_world_size): + if expected_count > int(total_count): pytest.skip( - f'MPI world size {mpi_world_size} is less than {expected_count}' - ) + f'Total world size {total_count} is less than {expected_count}') @pytest.fixture(autouse=True)