@@ -169,13 +169,16 @@ def drop_and_mask(keep_prob, seed=None):
169
169
170
170
dense_ = ops .Dense ()
171
171
def linear (input , weight , bias = None ):
172
- input = input .to (mindspore .float16 )
173
- weight = weight .to (mindspore .float16 )
174
- if bias is not None :
175
- bias = bias .to (mindspore .float16 )
176
- return dense_ (input , weight ) + bias
177
- return dense_ (input , weight )
178
-
172
+ if ON_ORANGE_PI :
173
+ input = input .to (mindspore .float16 )
174
+ weight = weight .to (mindspore .float16 )
175
+ if bias is not None :
176
+ bias = bias .to (mindspore .float16 )
177
+ return dense_ (input , weight ) + bias
178
+ return dense_ (input , weight )
179
+ if use_pyboost ():
180
+ return mindspore .mint .nn .functional .linear (input , weight , bias )
181
+ return dense_ (input , weight , bias )
179
182
180
183
181
184
def binary_cross_entropy_with_logits (input , target , weight = None , reduction = 'mean' , pos_weight = None ):
@@ -479,8 +482,8 @@ def addcmul(input, tensor1, tensor2, value=1):
479
482
return input + value * tensor1 * tensor2
480
483
481
484
def group_norm (input , num_groups , weight = None , bias = None , eps = 1e-5 ):
482
- # if use_pyboost():
483
- # return mindspore.mint.nn.functional.group_norm(input, num_groups, weight, bias, eps)
485
+ if use_pyboost ():
486
+ return mindspore .mint .nn .functional .group_norm (input , num_groups , weight , bias , eps )
484
487
485
488
input_shape = input .shape
486
489
N = input_shape [0 ]
@@ -491,8 +494,6 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
491
494
affine_param_shape = [1 ] * input .ndim
492
495
affine_param_shape [1 ] = C
493
496
affine_param_shape = tuple (affine_param_shape )
494
- print (affine_param_shape )
495
- print (out .shape )
496
497
if weight is not None and bias is not None :
497
498
# out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1)
498
499
out = addcmul (bias .view (affine_param_shape ), out , weight .view (affine_param_shape ), 1 )
0 commit comments