From 30133d57fc205488c2a7e1ea110ee46e93cb6801 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Fri, 19 Jul 2024 17:09:17 +0800
Subject: [PATCH] [Bugs] Fix attn mask (#852)

* [WIP]: Fix sequence parallel memory bottleneck in DPO

* loss mask before split

* refactor orpo

* fix attention_mask in preference_collate_fn

---------

Co-authored-by: RangiLyu
---
 xtuner/dataset/collate_fns/preference_collate_fn.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/xtuner/dataset/collate_fns/preference_collate_fn.py b/xtuner/dataset/collate_fns/preference_collate_fn.py
index 8a6060410..ca21613bb 100644
--- a/xtuner/dataset/collate_fns/preference_collate_fn.py
+++ b/xtuner/dataset/collate_fns/preference_collate_fn.py
@@ -58,7 +58,7 @@ def preference_collate_fn(instances: Sequence[Dict],
         labels = torch.stack(labels)
 
     if use_varlen_attn:
-        attention_mask = None
+        attention_mask = torch.ones_like(input_ids).bool()
         position_ids = torch.stack(position_ids, dim=0)
     else:
         # Some tokenizers have the same eos token and pad token, so input_ids
@@ -74,8 +74,10 @@ def preference_collate_fn(instances: Sequence[Dict],
         input_ids = pad_for_sequence_parallel(input_ids, pad_index)
         labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
         position_ids = pad_for_sequence_parallel(position_ids, 0)
-        if attention_mask is not None:
-            attention_mask = pad_for_sequence_parallel(attention_mask, 0)
+        # We use attention_mask to distinguish `input_ids` from
+        # (sequence parallel) pad tokens in `get_var_len_atten_logps` method of
+        # class `DPO` and `ORPO`
+        attention_mask = pad_for_sequence_parallel(attention_mask, 0)
         if use_varlen_attn:
             (cumulative_len, attention_mask
              ) = pad_cumulative_len_for_sequence_parallel(cumulative_len)
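
The sketch below illustrates why the patch builds an all-True boolean mask in the varlen branch instead of leaving `attention_mask = None`: after sequence-parallel padding, the mask is the only way to tell real tokens from the pad tokens that were appended. This is a minimal toy reproduction, not xtuner code; `pad_to_multiple` is a hypothetical stand-in for `pad_for_sequence_parallel` (the real helper pads based on the sequence parallel world size), and the final masking step only approximates what `get_var_len_atten_logps` in the `DPO`/`ORPO` classes does.

import torch

def pad_to_multiple(tensor, multiple, pad_value):
    # Hypothetical stand-in for xtuner's `pad_for_sequence_parallel`:
    # right-pads the last dimension so its length is divisible by `multiple`.
    seq_len = tensor.size(-1)
    pad_len = (multiple - seq_len % multiple) % multiple
    if pad_len == 0:
        return tensor
    pad = tensor.new_full((*tensor.shape[:-1], pad_len), pad_value)
    return torch.cat([tensor, pad], dim=-1)

# Packed batch of real tokens (varlen attention), batch size 1, 6 tokens.
input_ids = torch.tensor([[11, 12, 13, 21, 22, 23]])
pad_index = 0

# Before the fix the varlen branch set attention_mask = None, so nothing
# marked which positions were real after padding. After the fix, start from
# an all-True mask over the real tokens ...
attention_mask = torch.ones_like(input_ids).bool()

# ... then pad both tensors to a length divisible by the (assumed)
# sequence parallel world size; the mask is padded with 0 (False).
sp_world_size = 4
input_ids = pad_to_multiple(input_ids, sp_world_size, pad_index)
attention_mask = pad_to_multiple(attention_mask, sp_world_size, 0)

print(input_ids)       # tensor([[11, 12, 13, 21, 22, 23,  0,  0]])
print(attention_mask)  # last two positions are False

# Downstream log-prob gathering can now drop the sequence-parallel pads:
real_tokens = input_ids[attention_mask]
print(real_tokens)     # tensor([11, 12, 13, 21, 22, 23])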