@@ -463,11 +463,13 @@ def search(self,t):
        return ops.argmax(t, -1)
    def call(self, inputs, **kwargs):
        hidden_state, update_index, out_ids, flags = inputs[:]
+
        y = self.search(hidden_state)
        t = ops.full_like(y, self.end_token)
        y = ops.where(flags, y, t)
        start = [0, update_index]
        flags = y != self.end_token
+
        return ops.slice_update(out_ids, start, ops.cast(y, out_ids.dtype)), flags

class TopkSearch(GreedySearch):
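
For reference, a minimal sketch (not part of the patch and not the library's own driver code) of what one such greedy step does when driven with plain keras.ops. The shapes, end_token=0 and update_index=2 are illustrative assumptions; once a row emits end_token its flag goes False, so later steps keep writing end_token for that row.

import numpy as np
from keras import ops

end_token = 0
update_index = 2                                  # position written this step (assumed)
hidden_state = np.array([[[0.1, 2.0, 0.3]],       # sample 1: argmax -> token 1
                         [[0.9, 0.2, 0.1]]])      # sample 2: argmax -> token 0 (end)
out_ids = np.zeros((2, 5), dtype="int32")         # ids decoded so far
flags = np.array([[True], [True]])                # True = sequence still running

y = ops.argmax(hidden_state, -1)                           # greedy pick, shape (batch, 1)
y = ops.where(flags, y, ops.full_like(y, end_token))       # finished rows emit end_token
flags = ops.not_equal(y, end_token)                        # rows hitting end_token stop
out_ids = ops.slice_update(out_ids, [0, update_index], ops.cast(y, "int32"))
print(ops.convert_to_numpy(out_ids))              # column 2 now holds [1, 0]
print(ops.convert_to_numpy(flags))                # [[True], [False]]
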
@@ -943,10 +945,10 @@ def build(self, input_shape):
    def call(self, inputs):
        """If custom_position_ids is set, the second input is the custom position ids
        """
-        if self.custom_position_ids:
+        flag = isinstance(inputs, list)
+        if self.custom_position_ids or flag:
            inputs, position_ids = inputs
-            if 'int' not in K.dtype(position_ids):
-                position_ids = ops.cast(position_ids, 'int32')
+            position_ids = ops.cast(position_ids, 'int32')
        else:
            input_shape = ops.shape(inputs)
            batch_size, seq_len = input_shape[0], input_shape[1]
@@ -960,8 +962,8 @@ def call(self, inputs):
            embeddings_y = ops.take(embeddings, position_ids % self.input_dim)
            embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y
        else:
-            if self.custom_position_ids:
-                embeddings = ops.take(self.embeddings, position_ids)
+            if self.custom_position_ids or flag:
+                embeddings = ops.take(self.embeddings, position_ids, axis=0)
            else:
                embeddings = self.embeddings[None, :seq_len]
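
For reference, a hypothetical ToyPositionEmbedding (an assumption for illustration, not the class being patched) that reproduces the list-input branch introduced above: when the layer is called with a list, the second element is treated as custom position ids and used to gather rows of the embedding table with ops.take(..., axis=0).

import numpy as np
import keras
from keras import ops

class ToyPositionEmbedding(keras.layers.Layer):
    def __init__(self, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim

    def build(self, input_shape):
        self.embeddings = self.add_weight(
            name="embeddings", shape=(self.input_dim, self.output_dim))

    def call(self, inputs):
        if isinstance(inputs, list):              # custom position ids supplied
            inputs, position_ids = inputs
            position_ids = ops.cast(position_ids, "int32")
            return ops.take(self.embeddings, position_ids, axis=0)
        seq_len = ops.shape(inputs)[1]            # default: positions 0..seq_len-1
        return self.embeddings[None, :seq_len]

x = np.zeros((2, 4, 8), dtype="float32")
position_ids = np.array([[3, 2, 1, 0], [0, 1, 2, 3]])
layer = ToyPositionEmbedding(input_dim=16, output_dim=8)
print(layer([x, position_ids]).shape)             # (2, 4, 8)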