 from pathlib import Path
+import os
+import sys
 
 import numpy as np
 import pytest
@@ -54,6 +56,14 @@ def infer_config_from_state_dict(cls, state_dict):
         return {"input_size": state_dict["head_a.weight"].shape[1]}
 
 
+@pytest.fixture
+def device():
+    """Returns 'cpu' if on macOS in GitHub Actions, otherwise None."""
+    if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform == "darwin":
+        return "cpu"
+    return None
+
+
 @pytest.fixture(autouse=True)
 def mps_memory_cleanup():
     """Fixture to clean up MPS memory after each test."""
@@ -78,7 +88,7 @@ def batch_message():
 
 
 @pytest.mark.parametrize("input_size,output_size", [(4, 2), (6, 3), (8, 1)])
-def test_inference_shapes(input_size, output_size):
+def test_inference_shapes(input_size, output_size, device):
     data = np.random.randn(12, input_size)
     msg = AxisArray(
         data=data,
@@ -94,6 +104,7 @@ def test_inference_shapes(input_size, output_size):
             "input_size": input_size,
             "output_size": output_size,
         },
+        device=device,
     )
     out = proc(msg)[0]
     # Check output last dim matches output_size
@@ -138,13 +149,14 @@ def test_checkpoint_loading_and_weights(batch_message):
 
 
 @pytest.mark.parametrize("dropout", [0.0, 0.1, 0.5])
-def test_model_kwargs_propagation(dropout, batch_message):
+def test_model_kwargs_propagation(dropout, batch_message, device):
     proc = TorchModelProcessor(
         model_class=DUMMY_MODEL_CLASS,
         model_kwargs={
             "output_size": 2,
             "dropout": dropout,
         },
+        device=device,
     )
     proc(batch_message)
     model = proc._state.model
@@ -155,14 +167,15 @@ def test_model_kwargs_propagation(dropout, batch_message):
         assert model.dropout is None
 
 
-def test_partial_fit_changes_weights(batch_message):
+def test_partial_fit_changes_weights(batch_message, device):
     proc = TorchModelProcessor(
         model_class=DUMMY_MODEL_CLASS,
         loss_fn=torch.nn.MSELoss(),
         learning_rate=0.1,
         model_kwargs={
             "output_size": 2,
         },
+        device=device,
     )
     x = batch_message.data[:1]
     y = np.random.randn(1, 2)
@@ -198,6 +211,7 @@ def test_partial_fit_changes_weights(batch_message):
             "input_size": x.shape[-1],
             "output_size": 2,
         },
+        device=device,
     )
     bad_proc(sample)
     with pytest.raises(ValueError):
@@ -209,8 +223,11 @@ def test_model_runs_on_devices(device, batch_message):
     # Skip unavailable devices
     if device == "cuda" and not torch.cuda.is_available():
         pytest.skip("CUDA not available")
-    if device == "mps" and not torch.backends.mps.is_available():
-        pytest.skip("MPS not available")
+    if device == "mps":
+        if not torch.backends.mps.is_available():
+            pytest.skip("MPS not available")
+        if os.getenv("GITHUB_ACTIONS") == "true":
+            pytest.skip("MPS memory limit too low on free GitHub Actions runner")
 
     proc = TorchModelProcessor(
         model_class=DUMMY_MODEL_CLASS,
@@ -226,7 +243,7 @@ def test_model_runs_on_devices(device, batch_message):
 
 
 @pytest.mark.parametrize("batch_size", [1, 5, 10])
-def test_batch_processing(batch_size):
+def test_batch_processing(batch_size, device):
     input_dim = 4
     output_dim = 2
     data = np.random.randn(batch_size, input_dim)
@@ -245,6 +262,7 @@ def test_batch_processing(batch_size):
             "input_size": input_dim,
             "output_size": output_dim,
         },
+        device=device,
     )
     out = proc(msg)[0]
     assert out.data.shape[0] == batch_size
@@ -273,10 +291,11 @@ def test_input_size_mismatch_raises_error():
     )(msg)
 
 
-def test_multihead_output(batch_message):
+def test_multihead_output(batch_message, device):
     proc = TorchModelProcessor(
         model_class=MULTIHEAD_MODEL_CLASS,
         model_kwargs={"input_size": batch_message.data.shape[1]},
+        device=device,
     )
     results = proc(batch_message)
 
@@ -286,14 +305,15 @@ def test_multihead_output(batch_message):
         assert r.data.ndim == 2
 
 
-def test_multihead_partial_fit_with_loss_dict(batch_message):
+def test_multihead_partial_fit_with_loss_dict(batch_message, device):
     proc = TorchModelProcessor(
         model_class=MULTIHEAD_MODEL_CLASS,
         loss_fn={
             "head_a": torch.nn.MSELoss(),
             "head_b": torch.nn.L1Loss(),
         },
         model_kwargs={"input_size": batch_message.data.shape[1]},
+        device=device,
     )
 
     proc(batch_message)  # initialize model
@@ -324,7 +344,7 @@ def test_multihead_partial_fit_with_loss_dict(batch_message):
     assert not torch.allclose(before_b, after_b)
 
 
-def test_partial_fit_with_loss_weights(batch_message):
+def test_partial_fit_with_loss_weights(batch_message, device):
     proc = TorchModelProcessor(
         model_class=MULTIHEAD_MODEL_CLASS,
         loss_fn={
@@ -336,6 +356,7 @@ def test_partial_fit_with_loss_weights(batch_message):
             "head_b": 0.5,
         },
         model_kwargs={"input_size": batch_message.data.shape[1]},
+        device=device,
     )
     proc(batch_message)
 
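For context on how these changes interact, here is a minimal sketch of the pytest behavior the diff relies on: a @pytest.mark.parametrize("device", ...) argument shadows a fixture of the same name, so the already-parametrized test_model_runs_on_devices keeps iterating over explicit devices, while the other tests receive the CI-aware default from the new fixture. The test names below are illustrative, not taken from the diff.

import pytest

@pytest.fixture
def device():
    # CI-aware default, as in the fixture added above (simplified here).
    return None

@pytest.mark.parametrize("device", ["cpu", "cuda", "mps"])
def test_parametrized_device(device):
    # Parametrized values shadow the fixture; the fixture body never runs here.
    assert device in ("cpu", "cuda", "mps")

def test_default_device(device):
    # No parametrize on "device", so the fixture value (None here) is injected.
    assert device is None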