move get_working_device into method
yardeny-sony committed Sep 25, 2024
1 parent 6e5bec9 commit 60853f1
Showing 1 changed file with 2 additions and 2 deletions.
@@ -28,7 +28,6 @@ def __init__(self):
         Matches scaled_dot_product_attention node.
         """
         super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
-        self.device = get_working_device()
 
     def _get_input_by_name(self, attention_node, input_name, input_index, default_value):
         """
@@ -135,6 +134,7 @@ def _get_matmul2_node(self, attention_node, softmax_node, v_node):
                                      functional_op=torch.matmul)
 
     def _get_attention_mask_tensor(self, attention_node):
+        device = get_working_device()
         is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
         if is_causal:
             raise NotImplementedError("scaled_dot_product_attention is_causal feature is not implemented.")
@@ -146,7 +146,7 @@ def _get_attention_mask_tensor(self, attention_node):
         if attn_mask is not None and (not np.isfinite(attn_mask).all()):
             raise NotImplementedError(
                 "scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
-        return torch.from_numpy(attn_mask).to(self.device) if attn_mask is not None else None
+        return torch.from_numpy(attn_mask).to(device) if attn_mask is not None else None
 
     def _get_dropout_node(self, attention_node, in_out_shape):
         dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
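For context, the net effect of this change is that the working device is resolved when the attention mask is actually built, rather than being cached on the instance at construction time. Below is a minimal sketch of that pattern, assuming a get_working_device helper that prefers CUDA when available; the class and method names are simplified stand-ins for illustration, not the repository's code.

import numpy as np
import torch

def get_working_device() -> torch.device:
    # Stand-in for the utility referenced in the diff: prefer CUDA, fall back to CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MaskBuilder:
    # Old pattern: self.device = get_working_device() in __init__ pins whatever device
    # was active when the object was created. New pattern (below): resolve the device
    # inside the method, so each call reflects the device available at that moment.
    def build_mask(self, attn_mask):
        device = get_working_device()  # resolved lazily, per call
        if attn_mask is None:
            return None
        return torch.from_numpy(attn_mask).to(device)

# Usage: MaskBuilder().build_mask(np.zeros((1, 4), dtype=np.float32)) returns a tensor
# on the current working device, or None when no mask is given.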
