From bea4fd1ee342515ef537a331fe13f118ce971c23 Mon Sep 17 00:00:00 2001 From: xiuqhou Date: Tue, 29 Apr 2025 22:37:37 +0800 Subject: [PATCH 1/2] [Bug fix] attention_mask dtype in BiMultiHeadAttention (#12351) --- mmdet/models/utils/vlfuse_helper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mmdet/models/utils/vlfuse_helper.py b/mmdet/models/utils/vlfuse_helper.py index 76b54de317c..536a8bc0dd2 100644 --- a/mmdet/models/utils/vlfuse_helper.py +++ b/mmdet/models/utils/vlfuse_helper.py @@ -197,8 +197,9 @@ def forward( assert (attention_mask_l.dim() == 2) attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) - attention_mask = attention_mask.masked_fill( - attention_mask == 0, -9e15) + if attention_mask.dtype == torch.bool: + attention_mask = torch.zeros_like(attention_mask, dtype=query_states.dtype).masked_fill_( + attention_mask == True, -9e15) if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError('Attention mask should be of ' From a0c6d83dae9d2cd4c9433e435d1dab563f2029ae Mon Sep 17 00:00:00 2001 From: xiuqhou Date: Tue, 29 Apr 2025 23:06:22 +0800 Subject: [PATCH 2/2] Update code style --- mmdet/models/utils/vlfuse_helper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mmdet/models/utils/vlfuse_helper.py b/mmdet/models/utils/vlfuse_helper.py index 536a8bc0dd2..073f912b94d 100644 --- a/mmdet/models/utils/vlfuse_helper.py +++ b/mmdet/models/utils/vlfuse_helper.py @@ -198,8 +198,9 @@ def forward( attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) if attention_mask.dtype == torch.bool: - attention_mask = torch.zeros_like(attention_mask, dtype=query_states.dtype).masked_fill_( - attention_mask == True, -9e15) + attention_mask = torch.zeros_like( + attention_mask, dtype=query_states.dtype).masked_fill_( + attention_mask, -9e15) if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError('Attention mask should be of '