diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
index 22c14844a..7a9d5dfaf 100644
--- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
+++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
@@ -28,7 +28,6 @@ def __init__(self):
         Matches scaled_dot_product_attention node.
         """
         super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
-        self.device = get_working_device()
 
     def _get_input_by_name(self, attention_node, input_name, input_index, default_value):
         """
@@ -135,6 +134,7 @@ def _get_matmul2_node(self, attention_node, softmax_node, v_node):
                                   functional_op=torch.matmul)
 
     def _get_attention_mask_tensor(self, attention_node):
+        device = get_working_device()
         is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
         if is_causal:
             raise NotImplementedError("scaled_dot_product_attention is_causal feature is not implemented.")
@@ -146,7 +146,7 @@ def _get_attention_mask_tensor(self, attention_node):
         if attn_mask is not None and (not np.isfinite(attn_mask).all()):
             raise NotImplementedError(
                 "scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
-        return torch.from_numpy(attn_mask).to(self.device) if attn_mask is not None else None
+        return torch.from_numpy(attn_mask).to(device) if attn_mask is not None else None
 
     def _get_dropout_node(self, attention_node, in_out_shape):
         dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
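
For context, a minimal sketch of the pattern this diff applies: resolve the working device lazily at the point where the attention-mask tensor is created, instead of caching it on the matcher instance at construction time. The MaskBuilder class, build_mask method, and the local get_working_device helper below are hypothetical stand-ins for illustration only, not MCT's actual API.

# Minimal sketch (assumed names, not MCT's actual classes) of the pattern the
# diff applies: look up the working device where the tensor is created,
# rather than caching it on the instance at construction time.
from typing import Optional

import numpy as np
import torch


def get_working_device() -> torch.device:
    # Stand-in for MCT's get_working_device() utility.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MaskBuilder:
    def __init__(self):
        # Before the change, the device was cached here:
        #     self.device = get_working_device()
        # which pinned whatever device was current when the matcher was built.
        pass

    def build_mask(self, attn_mask: Optional[np.ndarray]) -> Optional[torch.Tensor]:
        # After the change, the device is resolved lazily, right where the
        # numpy mask is converted to a torch tensor.
        device = get_working_device()
        return torch.from_numpy(attn_mask).to(device) if attn_mask is not None else None


# Usage: a float32 mask is moved to whichever device is current at call time.
mask = MaskBuilder().build_mask(np.zeros((1, 1, 4, 4), dtype=np.float32))
print(None if mask is None else mask.device)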