diff --git a/mmdet/models/utils/vlfuse_helper.py b/mmdet/models/utils/vlfuse_helper.py index 76b54de317c..073f912b94d 100644 --- a/mmdet/models/utils/vlfuse_helper.py +++ b/mmdet/models/utils/vlfuse_helper.py @@ -197,8 +197,10 @@ def forward( assert (attention_mask_l.dim() == 2) attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) - attention_mask = attention_mask.masked_fill( - attention_mask == 0, -9e15) + if attention_mask.dtype == torch.bool: + attention_mask = torch.zeros_like( + attention_mask, dtype=query_states.dtype).masked_fill_( + attention_mask, -9e15) if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError('Attention mask should be of '