Commit ad863fd

Pytest speedup (#849)
* add code to measure time spent in pytest
* speed up datapipe tests
* fix cleanup of dist vars (was causing slowdown in test_capture.py)
* speed up model tests
* bring back some parameterizations, reduced cpu tests
1 parent 4de8534 commit ad863fd

17 files changed: +82 -52 lines

test/conftest.py (+22)

@@ -14,8 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import defaultdict
+
 import pytest
 
+file_timings = defaultdict(float)
+
+# Total time per file
+file_timings = defaultdict(float)
+
+
+def pytest_runtest_logreport(report):
+    if report.when == "call":
+        # report.nodeid format: path::TestClass::test_name
+        filename = report.nodeid.split("::")[0]
+        file_timings[filename] += report.duration
+
+
+def pytest_sessionfinish(session, exitstatus):
+    print("\n=== Test durations by file ===")
+    for filename, duration in sorted(
+        file_timings.items(), key=lambda x: x[1], reverse=True
+    ):
+        print(f"{filename}: {duration:.2f} seconds")
+
 
 def pytest_addoption(parser):
     parser.addoption(
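
For context, pytest calls pytest_runtest_logreport once per test phase (setup, call, teardown) and report.duration is the wall-clock time of that phase, so the report.when == "call" filter above counts only test bodies. A minimal sketch of a variant that also folds setup and teardown time into the per-file totals (not part of this commit; phase_timings is a hypothetical name):

# Sketch only: aggregate setup + call + teardown per file
# (hypothetical extension, not what the commit adds).
from collections import defaultdict

phase_timings = defaultdict(float)


def pytest_runtest_logreport(report):
    # report.when is "setup", "call", or "teardown"; report.duration is per phase.
    filename = report.nodeid.split("::")[0]
    phase_timings[filename] += report.duration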

test/datapipes/test_climate.py (+8, -5)

@@ -74,10 +74,10 @@ def geopotential_filename():
         shuffle=False,
     )
 
-
+# Skip CPU tests because too slow
 @nfsdata_or_fail
 @import_or_fail("netCDF4")
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_climate_hdf5_constructor(
     data_dir,
     stats_files,
@@ -230,12 +230,13 @@ def test_climate_hdf5_device(
         break
 
 
+# Skip CPU tests because too slow
 @nfsdata_or_fail
 @import_or_fail("netCDF4")
 @pytest.mark.parametrize("data_channels", [[0, 1]])
 @pytest.mark.parametrize("num_steps", [2])
-@pytest.mark.parametrize("batch_size", [1, 2, 3])
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+@pytest.mark.parametrize("batch_size", [2, 3])
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_climate_hdf5_shape(
     data_dir,
     stats_files,
@@ -322,11 +323,12 @@ def test_climate_hdf5_shape(
         break
 
 
+# Skip CPU tests because too slow
 @nfsdata_or_fail
 @import_or_fail("netCDF4")
 @pytest.mark.parametrize("num_steps", [1, 2])
 @pytest.mark.parametrize("stride", [1, 3])
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_era5_hdf5_sequence(
     data_dir,
     stats_files,
@@ -375,6 +377,7 @@ def test_era5_hdf5_sequence(
     )
 
 
+# Skip CPU tests because too slow
 @nfsdata_or_fail
 @import_or_fail("netCDF4")
 @pytest.mark.parametrize("shuffle", [True, False])

test/datapipes/test_mesh_datapipe.py (+5, -5)

@@ -118,7 +118,7 @@ def _create_random_vtp_vtu_mesh(
 
     tmp_dir = tmp_path / "temp_data"
     tmp_dir.mkdir()
-    _create_random_vtp_vtu_mesh(num_points=20, num_triangles=40, dir=tmp_dir)
+    _create_random_vtp_vtu_mesh(num_points=10, num_triangles=20, dir=tmp_dir)
     datapipe_vtp = MeshDatapipe(
         data_dir=tmp_dir,
         variables=["RandomFeatures"],
@@ -134,8 +134,8 @@ def _create_random_vtp_vtu_mesh(
 
     assert len(datapipe_vtp) == 1
     for data in datapipe_vtp:
-        assert data[0]["vertices"].shape == (1, 20, 3)
-        assert data[0]["x"].shape == (1, 20, 1)
+        assert data[0]["vertices"].shape == (1, 10, 3)
+        assert data[0]["x"].shape == (1, 10, 1)
 
     datapipe_vtu = MeshDatapipe(
         data_dir=tmp_dir,
@@ -152,8 +152,8 @@ def _create_random_vtp_vtu_mesh(
 
     assert len(datapipe_vtu) == 1
     for data in datapipe_vtu:
-        assert data[0]["vertices"].shape == (1, 20, 3)
-        assert data[0]["x"].shape == (1, 20, 1)
+        assert data[0]["vertices"].shape == (1, 10, 3)
+        assert data[0]["x"].shape == (1, 10, 1)
 
 
 # @nfsdata_or_fail
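
The shape assertions change in lockstep with the mesh size, since the checked tensors are (batch, num_points, 3) vertices and (batch, num_points, 1) features. A small sketch of deriving the expectations from one constant so a future size change touches a single line (a hypothetical refactor, not part of the commit):

# Sketch: derive expected shapes from one constant (hypothetical helper).
NUM_POINTS = 10


def expected_shapes(batch=1, num_feats=1):
    return {
        "vertices": (batch, NUM_POINTS, 3),
        "x": (batch, NUM_POINTS, num_feats),
    }


assert expected_shapes() == {"vertices": (1, 10, 3), "x": (1, 10, 1)}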

test/metrics/test_metrics_general.py (+18, -10)

@@ -64,8 +64,8 @@ def get_disagreements(inputs, bins, counts, test):
     print("True counts", trueh)
 
 
-@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
-@pytest.mark.parametrize("input_shape", [(1, 72, 144), (1, 360, 720)])
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+@pytest.mark.parametrize("input_shape", [(1, 72, 144)])
 def test_histogram(device, input_shape, rtol: float = 1e-3, atol: float = 1e-3):
     DistributedManager._shared_state = {}
     if (device == "cuda:0") and (not DistributedManager.is_initialized()):
@@ -225,6 +225,10 @@ def test_histogram(device, input_shape, rtol: float = 1e-3, atol: float = 1e-3):
     )
     if device == "cuda:0":
         DistributedManager.cleanup()
+        del os.environ["RANK"]
+        del os.environ["WORLD_SIZE"]
+        del os.environ["MASTER_ADDR"]
+        del os.environ["MASTER_PORT"]
 
 
 def fair_crps(pred, obs, dim=-1):
@@ -539,8 +543,8 @@ def test_crps(device, rtol: float = 1e-3, atol: float = 1e-3):
 
 
 @pytest.mark.parametrize("device", ["cuda:0", "cpu"])
-@pytest.mark.parametrize("mean", [0.0, 3.0])
-@pytest.mark.parametrize("variance", [1.0, 0.1, 3.0])
+@pytest.mark.parametrize("mean", [3.0])
+@pytest.mark.parametrize("variance", [0.1])
 def test_wasserstein(device, mean, variance, rtol: float = 1e-3, atol: float = 1e-3):
     mean = torch.as_tensor([mean], device=device, dtype=torch.float32)
     variance = torch.as_tensor([variance], device=device, dtype=torch.float32)
@@ -704,6 +708,10 @@ def test_means_var(device, rtol: float = 1e-3, atol: float = 1e-3):
 
     if device == "cuda:0":
         DistributedManager.cleanup()
+        del os.environ["RANK"]
+        del os.environ["WORLD_SIZE"]
+        del os.environ["MASTER_ADDR"]
+        del os.environ["MASTER_PORT"]
 
 
 @pytest.mark.parametrize("device", ["cuda:0", "cpu"])
@@ -781,7 +789,7 @@ def test_calibration(device, rtol: float = 1e-2, atol: float = 1e-2):
 def test_entropy(device, rtol: float = 1e-2, atol: float = 1e-2):
     one = torch.ones([1], device=device, dtype=torch.float32)
 
-    x = torch.randn((100_000, 10, 10), device=device, dtype=torch.float32)
+    x = torch.randn((50_000, 10, 10), device=device, dtype=torch.float32)
     bin_edges, bin_counts = hist.histogram(x, bins=30)
     entropy = ent.entropy_from_counts(bin_counts, bin_edges, normalized=False)
     assert entropy.shape == (10, 10)
@@ -810,11 +818,11 @@ def test_entropy(device, rtol: float = 1e-2, atol: float = 1e-2):
     assert torch.allclose(entropy, one, rtol=rtol, atol=atol)
 
     # Test Relative Entropy
-    x = torch.randn((500_000, 10, 10), device=device, dtype=torch.float32)
+    x = torch.randn((100_000, 10, 10), device=device, dtype=torch.float32)
     bin_edges, x_bin_counts = hist.histogram(x, bins=30)
-    x1 = torch.randn((500_000, 10, 10), device=device, dtype=torch.float32)
+    x1 = torch.randn((100_000, 10, 10), device=device, dtype=torch.float32)
     _, x1_bin_counts = hist.histogram(x1, bins=bin_edges)
-    x2 = 0.1 * torch.randn((100_000, 10, 10), device=device, dtype=torch.float32)
+    x2 = 0.1 * torch.randn((50_000, 10, 10), device=device, dtype=torch.float32)
     _, x2_bin_counts = hist.histogram(x2, bins=bin_edges)
 
     rel_ent_1 = ent.relative_entropy_from_counts(x_bin_counts, x1_bin_counts, bin_edges)
@@ -847,8 +855,8 @@ def test_entropy(device, rtol: float = 1e-2, atol: float = 1e-2):
 
 @pytest.mark.parametrize("device", ["cuda:0", "cpu"])
 def test_power_spectrum(device):
-    """Test the 2D power spectrum routine for correctness using a sine wave"""
-    h, w = 64, 64
+    # Test the 2D power spectrum routine for correctness using a sine wave
+    h, w = 32, 32
     kx, ky = 4, 4
     amplitude = 1.0
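
The new del os.environ[...] lines remove the RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT variables used by the single-process distributed setup, so they do not linger for tests that run afterwards; the commit message ties the leftover values to a slowdown in test_capture.py. A defensive sketch of the same cleanup as a helper, using pop() so a missing variable is not an error (hypothetical; the commit deletes the keys inline):

# Sketch: remove the distributed-env variables even if some were never set
# (hypothetical helper, not part of the commit).
import os

_DIST_VARS = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")


def clear_dist_env():
    for var in _DIST_VARS:
        os.environ.pop(var, None)  # no KeyError if the variable is absent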

test/models/data/dlwp_healpix.pth (and six other binary files)

-960 KB (binary file not shown)
-960 KB (binary file not shown)
-960 KB (binary file not shown)
0 Bytes (binary file not shown)
0 Bytes (binary file not shown)
0 Bytes (binary file not shown)
Binary file not shown

test/models/diffusion/test_dhariwal_unet.py (+7, -5)

@@ -63,19 +63,20 @@ def test_dhariwal_unet_constructor(device):
     assert output_image.shape == (1, out_channels, img_resolution, img_resolution)
 
 
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+# Skip CPU tests because too slow
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_dhariwal_unet_optims(device):
     """Test Dhariwal UNet optimizations"""
 
     def setup_model():
         model = UNet(
-            img_resolution=16,
+            img_resolution=8,
             in_channels=2,
             out_channels=2,
         ).to(device)
         noise_labels = torch.randn([1]).to(device)
         class_labels = torch.randint(0, 1, (1, 1)).to(device)
-        input_image = torch.ones([1, 2, 16, 16]).to(device)
+        input_image = torch.ones([1, 2, 8, 8]).to(device)
 
         return model, [input_image, noise_labels, class_labels]
 
@@ -94,7 +95,8 @@ def setup_model():
     assert common.validate_combo_optims(model, (*invar,))
 
 
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+# Skip CPU tests because too slow
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_dhariwal_unet_checkpoint(device):
     """Test Dhariwal UNet checkpoint save/load"""
     # Construct FNO models
@@ -113,7 +115,7 @@ def test_dhariwal_unet_checkpoint(device):
     # Change the bias in the last layer of the second model as a hack
     # Because this model is initialized with all zeros
     with torch.no_grad():
-        model_2.out_conv.bias += 1
+        model_2.out_conv.bias.add_(1)
 
     noise_labels = torch.randn([1]).to(device)
     class_labels = torch.randint(0, 1, (1, 1)).to(device)
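
On the bias change: under torch.no_grad() both forms mutate the parameter in place, but the augmented assignment also re-assigns the attribute through nn.Module.__setattr__, whereas .add_(1) touches only the tensor. A minimal illustration (standalone; conv is a stand-in for the model's output layer):

# Sketch: both lines update the bias in place under no_grad; add_() avoids the
# extra attribute re-assignment that "+=" performs via Module.__setattr__.
import torch
import torch.nn as nn

conv = nn.Conv2d(2, 2, kernel_size=1)
with torch.no_grad():
    conv.bias.add_(1)   # pure in-place update
    # conv.bias += 1    # also in-place, but re-assigns conv.bias afterwards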

test/models/diffusion/test_song_unet_pos_embd.py (+4, -2)

@@ -243,7 +243,8 @@ def test_fails_if_grid_is_invalid():
     )
 
 
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+# Skip CPU tests because too slow
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_song_unet_optims(device):
     """Test Song UNet optimizations"""
 
@@ -278,7 +279,8 @@ def setup_model():
     assert common.validate_combo_optims(model, (*invar,))
 
 
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+# Skip CPU tests because too slow
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_song_unet_checkpoint(device):
     """Test Song UNet checkpoint save/load"""
     # Construct FNO models

test/models/diffusion/test_song_unet_pos_lt_embd.py (+4, -2)

@@ -324,7 +324,8 @@ def test_fails_if_grid_is_invalid():
     )
 
 
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+# Skip CPU tests because too slow
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_song_unet_optims(device):
     """Test Song UNet optimizations"""
 
@@ -359,7 +360,8 @@ def setup_model():
     assert common.validate_combo_optims(model, (*invar,))
 
 
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+# Skip CPU tests because too slow
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_song_unet_checkpoint(device):
     """Test Song UNet checkpoint save/load"""
     # Construct FNO models

test/models/dlwp_healpix/test_healpix_recunet_model.py (+7, -7)

@@ -172,8 +172,8 @@ def generate_insolation_data(batch_size=8, time_dim=1, img_size=16, device="cpu"
 @import_or_fail("omegaconf")
 @pytest.mark.parametrize("device", ["cuda:0", "cpu"])
 def test_HEALPixRecUNet_initialize(device, encoder_dict, decoder_dict, pytestconfig):
-    in_channels = 7
-    out_channels = 7
+    in_channels = 3
+    out_channels = 3
     n_constants = 1
     decoder_input_channels = 1
     input_time_dim = 2
@@ -314,8 +314,8 @@ def test_HEALPixRecUNet_reset(
     pytestconfig,
 ):
     # create a smaller version of the dlwp healpix model
-    in_channels = 3
-    out_channels = 3
+    in_channels = 2
+    out_channels = 2
     n_constants = 2
     decoder_input_channels = 1
     input_time_dim = 2
@@ -366,13 +366,13 @@ def test_HEALPixRecUNet_forward(
     pytestconfig,
 ):
     # create a smaller version of the dlwp healpix model
-    in_channels = 3
-    out_channels = 3
+    in_channels = 2
+    out_channels = 2
     n_constants = 2
     decoder_input_channels = 1
     input_time_dim = 2
     output_time_dim = 4
-    batch_size = 8
+    batch_size = 2
     size = 16
 
     fix_random_seeds(seed=42)

test/models/test_dlwp.py (+4, -4)

@@ -47,10 +47,10 @@ def test_dlwp_forward(device):
 
 
 @pytest.mark.parametrize("device", ["cuda:0", "cpu"])
-@pytest.mark.parametrize("nr_input_channels", [2, 4])
-@pytest.mark.parametrize("nr_output_channels", [2, 4])
-@pytest.mark.parametrize("nr_initial_channels", [32, 64])
-@pytest.mark.parametrize("depth", [2, 3, 4])
+@pytest.mark.parametrize("nr_input_channels", [2])
+@pytest.mark.parametrize("nr_output_channels", [2])
+@pytest.mark.parametrize("nr_initial_channels", [32])
+@pytest.mark.parametrize("depth", [2])
 def test_dlwp_constructor(
     device, nr_input_channels, nr_output_channels, nr_initial_channels, depth
 ):
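
Stacked @pytest.mark.parametrize decorators generate the Cartesian product of their values, so the old settings produced 2 x 2 x 2 x 3 = 24 constructor tests per device while the new ones produce a single case. A tiny illustration of the multiplication (generic example, not from the repo):

# Sketch: stacked parametrize decorators multiply; this collects 2 * 3 = 6 tests.
import pytest


@pytest.mark.parametrize("a", [1, 2])
@pytest.mark.parametrize("b", [10, 20, 30])
def test_product(a, b):
    assert a * b > 0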

test/models/test_swinrnn.py (+3, -12)

@@ -24,7 +24,8 @@
 from . import common
 
 
-@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+# Skip CPU tests because too slow
+@pytest.mark.parametrize("device", ["cuda:0"])
 def test_swinrnn_forward(device):
     """Test SwinRNN forward pass"""
     torch.manual_seed(0)
@@ -43,7 +44,7 @@ def test_swinrnn_forward(device):
     invar = torch.randn(bsize, 13, 6, 32, 64).to(device)
     # Check output size
     with torch.no_grad():
-        assert common.validate_forward_accuracy(model, (invar,), atol=5e-3)
+        assert common.validate_forward_accuracy(model, (invar,), atol=5e-3, rtol=1e-3)
     del invar, model
     torch.cuda.empty_cache()
 
@@ -53,16 +54,6 @@ def test_swinrnn_constructor(device):
     """Test SwinRNN constructor options"""
     # Define dictionary of constructor args
    arg_list = [
-        {
-            "img_size": (6, 32, 64),
-            "patch_size": (6, 1, 1),
-            "in_chans": 13,
-            "out_chans": 13,
-            "embed_dim": 768,
-            "num_groups": 32,
-            "num_heads": 8,
-            "window_size": 8,
-        },
         {
             "img_size": (3, 32, 32),
             "patch_size": (3, 1, 1),
