Skip to content

Commit ce8e69f

Browse files
committed
Use "cpu" instead of "mps" when running on GitHub Actions.
1 parent d09fbed commit ce8e69f

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

tests/unit/test_torch.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from pathlib import Path
2+
import os
3+
import sys
24

35
import numpy as np
46
import pytest
@@ -54,6 +56,14 @@ def infer_config_from_state_dict(cls, state_dict):
5456
return {"input_size": state_dict["head_a.weight"].shape[1]}
5557

5658

59+
@pytest.fixture
60+
def device():
61+
"""Returns 'cpu' if on macOS in GitHub Actions, otherwise None."""
62+
if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform == "darwin":
63+
return "cpu"
64+
return None
65+
66+
5767
@pytest.fixture(autouse=True)
5868
def mps_memory_cleanup():
5969
"""Fixture to clean up MPS memory after each test."""
@@ -78,7 +88,7 @@ def batch_message():
7888

7989

8090
@pytest.mark.parametrize("input_size,output_size", [(4, 2), (6, 3), (8, 1)])
81-
def test_inference_shapes(input_size, output_size):
91+
def test_inference_shapes(input_size, output_size, device):
8292
data = np.random.randn(12, input_size)
8393
msg = AxisArray(
8494
data=data,
@@ -94,6 +104,7 @@ def test_inference_shapes(input_size, output_size):
94104
"input_size": input_size,
95105
"output_size": output_size,
96106
},
107+
device=device,
97108
)
98109
out = proc(msg)[0]
99110
# Check output last dim matches output_size
@@ -138,13 +149,14 @@ def test_checkpoint_loading_and_weights(batch_message):
138149

139150

140151
@pytest.mark.parametrize("dropout", [0.0, 0.1, 0.5])
141-
def test_model_kwargs_propagation(dropout, batch_message):
152+
def test_model_kwargs_propagation(dropout, batch_message, device):
142153
proc = TorchModelProcessor(
143154
model_class=DUMMY_MODEL_CLASS,
144155
model_kwargs={
145156
"output_size": 2,
146157
"dropout": dropout,
147158
},
159+
device=device,
148160
)
149161
proc(batch_message)
150162
model = proc._state.model
@@ -155,14 +167,15 @@ def test_model_kwargs_propagation(dropout, batch_message):
155167
assert model.dropout is None
156168

157169

158-
def test_partial_fit_changes_weights(batch_message):
170+
def test_partial_fit_changes_weights(batch_message, device):
159171
proc = TorchModelProcessor(
160172
model_class=DUMMY_MODEL_CLASS,
161173
loss_fn=torch.nn.MSELoss(),
162174
learning_rate=0.1,
163175
model_kwargs={
164176
"output_size": 2,
165177
},
178+
device=device,
166179
)
167180
x = batch_message.data[:1]
168181
y = np.random.randn(1, 2)
@@ -198,6 +211,7 @@ def test_partial_fit_changes_weights(batch_message):
198211
"input_size": x.shape[-1],
199212
"output_size": 2,
200213
},
214+
device=device,
201215
)
202216
bad_proc(sample)
203217
with pytest.raises(ValueError):
@@ -209,8 +223,11 @@ def test_model_runs_on_devices(device, batch_message):
209223
# Skip unavailable devices
210224
if device == "cuda" and not torch.cuda.is_available():
211225
pytest.skip("CUDA not available")
212-
if device == "mps" and not torch.backends.mps.is_available():
213-
pytest.skip("MPS not available")
226+
if device == "mps":
227+
if not torch.backends.mps.is_available():
228+
pytest.skip("MPS not available")
229+
if os.getenv("GITHUB_ACTIONS") == "true":
230+
pytest.skip("MPS memory limit too low on free GitHub Actions runner")
214231

215232
proc = TorchModelProcessor(
216233
model_class=DUMMY_MODEL_CLASS,
@@ -226,7 +243,7 @@ def test_model_runs_on_devices(device, batch_message):
226243

227244

228245
@pytest.mark.parametrize("batch_size", [1, 5, 10])
229-
def test_batch_processing(batch_size):
246+
def test_batch_processing(batch_size, device):
230247
input_dim = 4
231248
output_dim = 2
232249
data = np.random.randn(batch_size, input_dim)
@@ -245,6 +262,7 @@ def test_batch_processing(batch_size):
245262
"input_size": input_dim,
246263
"output_size": output_dim,
247264
},
265+
device=device,
248266
)
249267
out = proc(msg)[0]
250268
assert out.data.shape[0] == batch_size
@@ -273,10 +291,11 @@ def test_input_size_mismatch_raises_error():
273291
)(msg)
274292

275293

276-
def test_multihead_output(batch_message):
294+
def test_multihead_output(batch_message, device):
277295
proc = TorchModelProcessor(
278296
model_class=MULTIHEAD_MODEL_CLASS,
279297
model_kwargs={"input_size": batch_message.data.shape[1]},
298+
device=device,
280299
)
281300
results = proc(batch_message)
282301

@@ -286,14 +305,15 @@ def test_multihead_output(batch_message):
286305
assert r.data.ndim == 2
287306

288307

289-
def test_multihead_partial_fit_with_loss_dict(batch_message):
308+
def test_multihead_partial_fit_with_loss_dict(batch_message, device):
290309
proc = TorchModelProcessor(
291310
model_class=MULTIHEAD_MODEL_CLASS,
292311
loss_fn={
293312
"head_a": torch.nn.MSELoss(),
294313
"head_b": torch.nn.L1Loss(),
295314
},
296315
model_kwargs={"input_size": batch_message.data.shape[1]},
316+
device=device,
297317
)
298318

299319
proc(batch_message) # initialize model
@@ -324,7 +344,7 @@ def test_multihead_partial_fit_with_loss_dict(batch_message):
324344
assert not torch.allclose(before_b, after_b)
325345

326346

327-
def test_partial_fit_with_loss_weights(batch_message):
347+
def test_partial_fit_with_loss_weights(batch_message, device):
328348
proc = TorchModelProcessor(
329349
model_class=MULTIHEAD_MODEL_CLASS,
330350
loss_fn={
@@ -336,6 +356,7 @@ def test_partial_fit_with_loss_weights(batch_message):
336356
"head_b": 0.5,
337357
},
338358
model_kwargs={"input_size": batch_message.data.shape[1]},
359+
device=device,
339360
)
340361
proc(batch_message)
341362

0 commit comments

Comments
 (0)