Skip to content

Commit 8c842ec

Browse files
committed
refactor: refactor test utils
1 parent 96b37dc commit 8c842ec

File tree

1 file changed

+31
-32
lines changed

1 file changed

+31
-32
lines changed

tests/modeling_test_utils.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -217,48 +217,47 @@ def set_dtype(model, dtype):
217217
return model
218218

219219

220+
def _generate_inputs(base_input, dtype, framework):
221+
if isinstance(base_input, np.ndarray):
222+
if base_input.dtype in (np.float16, np.float32, np.float64, bfloat16):
223+
base_input = base_input.astype(dtype)
224+
225+
if framework == "torch":
226+
return (
227+
torch.from_numpy(base_input.astype(np.float32)).to(torch.bfloat16)
228+
if dtype == bfloat16
229+
else torch.from_numpy(base_input)
230+
)
231+
elif framework == "mindspore":
232+
return ms.Tensor.from_numpy(base_input)
233+
else:
234+
raise ValueError(f"Unsupported framework: {framework}")
235+
236+
elif isinstance(base_input, (tuple, list)):
237+
sequence_cls = type(base_input)
238+
return sequence_cls(_generate_inputs(x, dtype) for x in base_input)
239+
240+
elif isinstance(base_input, dict):
241+
return {k: _generate_inputs(v, dtype) for k, v in base_input.items()}
242+
243+
else:
244+
return base_input
245+
246+
220247
def generalized_parse_args(pt_dtype, ms_dtype, *args, **kwargs):
221248
# parse args
222249
pt_inputs_args = tuple()
223250
ms_inputs_args = tuple()
224251
for x in args:
225-
if isinstance(x, np.ndarray):
226-
if x.dtype in (np.float16, np.float32, np.float64, bfloat16):
227-
px = x.astype(NP_DTYPE_MAPPING[pt_dtype])
228-
mx = x.astype(NP_DTYPE_MAPPING[ms_dtype])
229-
else:
230-
px = mx = x
231-
232-
pt_inputs_args += (
233-
(torch.from_numpy(px.astype(np.float32)).to(torch.bfloat16),)
234-
if pt_dtype == "bf16"
235-
else (torch.from_numpy(px),)
236-
)
237-
ms_inputs_args += (ms.Tensor.from_numpy(mx),)
238-
else:
239-
pt_inputs_args += (x,)
240-
ms_inputs_args += (x,)
252+
pt_inputs_args += (_generate_inputs(x, NP_DTYPE_MAPPING[pt_dtype], "torch"),)
253+
ms_inputs_args += (_generate_inputs(x, NP_DTYPE_MAPPING[ms_dtype], "mindspore"),)
241254

242255
# parse kwargs
243256
pt_inputs_kwargs = dict()
244257
ms_inputs_kwargs = dict()
245258
for k, v in kwargs.items():
246-
if isinstance(v, np.ndarray):
247-
if v.dtype in (np.float16, np.float32, np.float64, bfloat16):
248-
px = v.astype(NP_DTYPE_MAPPING[pt_dtype])
249-
mx = v.astype(NP_DTYPE_MAPPING[ms_dtype])
250-
else:
251-
px = mx = v
252-
253-
pt_inputs_kwargs[k] = (
254-
torch.from_numpy(px.astype(np.float32)).to(torch.bfloat16)
255-
if pt_dtype == "bf16"
256-
else torch.from_numpy(px)
257-
)
258-
ms_inputs_kwargs[k] = ms.Tensor.from_numpy(mx)
259-
else:
260-
pt_inputs_kwargs[k] = v
261-
ms_inputs_kwargs[k] = v
259+
pt_inputs_kwargs[k] = _generate_inputs(v, NP_DTYPE_MAPPING[pt_dtype], "torch")
260+
ms_inputs_kwargs[k] = _generate_inputs(v, NP_DTYPE_MAPPING[ms_dtype], "mindspore")
262261

263262
return pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs
264263

0 commit comments

Comments
 (0)