@@ -217,48 +217,47 @@ def set_dtype(model, dtype):
217217 return model
218218
219219
220+ def _generate_inputs (base_input , dtype , framework ):
221+ if isinstance (base_input , np .ndarray ):
222+ if base_input .dtype in (np .float16 , np .float32 , np .float64 , bfloat16 ):
223+ base_input = base_input .astype (dtype )
224+
225+ if framework == "torch" :
226+ return (
227+ torch .from_numpy (base_input .astype (np .float32 )).to (torch .bfloat16 )
228+ if dtype == bfloat16
229+ else torch .from_numpy (base_input )
230+ )
231+ elif framework == "mindspore" :
232+ return ms .Tensor .from_numpy (base_input )
233+ else :
234+ raise ValueError (f"Unsupported framework: { framework } " )
235+
236+ elif isinstance (base_input , (tuple , list )):
237+ sequence_cls = type (base_input )
238+ return sequence_cls (_generate_inputs (x , dtype ) for x in base_input )
239+
240+ elif isinstance (base_input , dict ):
241+ return {k : _generate_inputs (v , dtype ) for k , v in base_input .items ()}
242+
243+ else :
244+ return base_input
245+
246+
220247def generalized_parse_args (pt_dtype , ms_dtype , * args , ** kwargs ):
221248 # parse args
222249 pt_inputs_args = tuple ()
223250 ms_inputs_args = tuple ()
224251 for x in args :
225- if isinstance (x , np .ndarray ):
226- if x .dtype in (np .float16 , np .float32 , np .float64 , bfloat16 ):
227- px = x .astype (NP_DTYPE_MAPPING [pt_dtype ])
228- mx = x .astype (NP_DTYPE_MAPPING [ms_dtype ])
229- else :
230- px = mx = x
231-
232- pt_inputs_args += (
233- (torch .from_numpy (px .astype (np .float32 )).to (torch .bfloat16 ),)
234- if pt_dtype == "bf16"
235- else (torch .from_numpy (px ),)
236- )
237- ms_inputs_args += (ms .Tensor .from_numpy (mx ),)
238- else :
239- pt_inputs_args += (x ,)
240- ms_inputs_args += (x ,)
252+ pt_inputs_args += (_generate_inputs (x , NP_DTYPE_MAPPING [pt_dtype ], "torch" ),)
253+ ms_inputs_args += (_generate_inputs (x , NP_DTYPE_MAPPING [ms_dtype ], "mindspore" ),)
241254
242255 # parse kwargs
243256 pt_inputs_kwargs = dict ()
244257 ms_inputs_kwargs = dict ()
245258 for k , v in kwargs .items ():
246- if isinstance (v , np .ndarray ):
247- if v .dtype in (np .float16 , np .float32 , np .float64 , bfloat16 ):
248- px = v .astype (NP_DTYPE_MAPPING [pt_dtype ])
249- mx = v .astype (NP_DTYPE_MAPPING [ms_dtype ])
250- else :
251- px = mx = v
252-
253- pt_inputs_kwargs [k ] = (
254- torch .from_numpy (px .astype (np .float32 )).to (torch .bfloat16 )
255- if pt_dtype == "bf16"
256- else torch .from_numpy (px )
257- )
258- ms_inputs_kwargs [k ] = ms .Tensor .from_numpy (mx )
259- else :
260- pt_inputs_kwargs [k ] = v
261- ms_inputs_kwargs [k ] = v
259+ pt_inputs_kwargs [k ] = _generate_inputs (v , NP_DTYPE_MAPPING [pt_dtype ], "torch" )
260+ ms_inputs_kwargs [k ] = _generate_inputs (v , NP_DTYPE_MAPPING [ms_dtype ], "mindspore" )
262261
263262 return pt_inputs_args , pt_inputs_kwargs , ms_inputs_args , ms_inputs_kwargs
264263
0 commit comments