Skip to content

Commit f08bbdd

Browse files
author
Mark-ZhouWX
committed
optimize performance with bmm fp32->16
1 parent 22ead32 commit f08bbdd

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

research/segment-anything/segment_anything/modeling/transformer.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -228,13 +228,13 @@ def construct(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
 
         # Attention
         _, _, _, c_per_head = q.shape
+        dtype = q.dtype
         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
-        attn = attn / Tensor(math.sqrt(c_per_head), ms.float32)
+        attn = attn / Tensor(math.sqrt(c_per_head), dtype)
         attn = ops.softmax(attn, axis=-1)
 
         # Get output
-        dtype = attn.dtype
-        out = attn @ v.astype(dtype)
+        out = attn @ v
         out = self._recombine_heads(out)
         out = self.out_proj(out)
 
0 commit comments

Comments
 (0)