Skip to content

Commit 48d8e7d

Browse files
authored
Add files via upload
1 parent 7e23d60 commit 48d8e7d

File tree

4 files changed

+196
-33
lines changed

4 files changed

+196
-33
lines changed
12 Bytes
Binary file not shown.
2.9 KB
Binary file not shown.

bert4keras3/layers.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,13 @@ def search(self,t):
463463
return ops.argmax(t,-1)
464464
def call(self, inputs, **kwargs):
465465
hidden_state,update_index,out_ids,flags = inputs[:]
466+
466467
y = self.search(hidden_state)
467468
t = ops.full_like(y,self.end_token)
468469
y = ops.where(flags,y,t)
469470
start = [0,update_index]
470471
flags = y!=self.end_token
472+
471473
return ops.slice_update(out_ids,start,ops.cast(y,out_ids.dtype)),flags
472474

473475
class TopkSearch(GreedySearch):
@@ -943,10 +945,10 @@ def build(self, input_shape):
943945
def call(self, inputs):
944946
"""如果custom_position_ids,那么第二个输入为自定义的位置id
945947
"""
946-
if self.custom_position_ids:
948+
flag = isinstance(inputs,list)
949+
if self.custom_position_ids or flag :
947950
inputs, position_ids = inputs
948-
if 'int' not in K.dtype(position_ids):
949-
position_ids = ops.cast(position_ids, 'int32')
951+
position_ids = ops.cast(position_ids, 'int32')
950952
else:
951953
input_shape = ops.shape(inputs)
952954
batch_size, seq_len = input_shape[0], input_shape[1]
@@ -960,8 +962,8 @@ def call(self, inputs):
960962
embeddings_y = ops.take(embeddings, position_ids % self.input_dim)
961963
embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y
962964
else:
963-
if self.custom_position_ids:
964-
embeddings = ops.take(self.embeddings, position_ids)
965+
if self.custom_position_ids or flag :
966+
embeddings = ops.take(self.embeddings, position_ids,axis=0)
965967
else:
966968
embeddings = self.embeddings[None, :seq_len]
967969

0 commit comments

Comments
 (0)