diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index f2120f3a73..ccceacff85 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1604,7 +1604,7 @@ def flash_attn_fwd_softmax_lse_correction(
     """Merge softmax stats of each step in Attention with context parallelism"""
     max_scale = torch.max(softmax_lse, softmax_lse_per_step)
     min_scale = torch.min(softmax_lse, softmax_lse_per_step)
-    new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
+    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
     softmax_lse.copy_(new_scale)
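
The patch replaces `log(1 + exp(min_scale - max_scale))` with `log1p(exp(min_scale - max_scale))`: `torch.log1p` evaluates `log(1 + x)` accurately for small `x`, whereas forming `1 + x` first can round the tiny exponential away entirely. A minimal standalone sketch of the idea (the helper name and tensor values below are illustrative, not taken from the PR):

```python
import torch

def merge_softmax_lse(softmax_lse, softmax_lse_per_step):
    """Numerically stable merge of two log-sum-exp stats, mirroring the patched line."""
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    # log1p(x) computes log(1 + x) without first rounding 1 + x, so the
    # correction term survives even when x drops below machine epsilon.
    return max_scale + torch.log1p(torch.exp(min_scale - max_scale))

# Illustrative only: when the two stats differ by 40, exp(min - max) = exp(-40)
# is ~4.2e-18, which is smaller than float64 epsilon (~2.2e-16).
x = torch.exp(torch.tensor(-40.0, dtype=torch.float64))
print(torch.log(1 + x))  # tensor(0.)         -- 1 + x rounds to 1, the term is lost
print(torch.log1p(x))    # tensor(4.2484e-18) -- the term is preserved
```

The merged value itself changes little in most cases, since the correction term is added to `max_scale`; the switch to `log1p` is the standard way to keep the `log(1 + x)` step as accurate as the surrounding arithmetic allows.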