From 7b6f584feeb3d5f6681f6ad56bf2aa2757f11bf1 Mon Sep 17 00:00:00 2001
From: Dakota
Date: Wed, 12 Jul 2023 09:02:51 -0500
Subject: [PATCH] set saving check to dp_rank == 0 in case of different layouts.

---
 llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama.py b/llama.py
index 1a7a61f..46b5f48 100644
--- a/llama.py
+++ b/llama.py
@@ -967,7 +967,7 @@ def main(llama: Path, tokenizer: Path, tp_world: int, pp_world: int, save_to: Pa
     print("done", flush=True)
     torch.distributed.barrier()
 
-    if rank < (tensor_model_parallel_size * pipeline_model_parallel_size):
+    if dp_rank == 0:
         # save the state dict to sharded files
         os.makedirs(save_to, exist_ok=True)
         torch.save(
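
Note (not part of the patch): a minimal sketch of why the layout-agnostic check matters. It assumes a Megatron-style rank ordering (tensor-parallel innermost, then data-parallel, then pipeline-parallel); the decompose helper and the world sizes below are hypothetical and only illustrate that the first tp_world * pp_world global ranks are not necessarily the dp_rank == 0 processes, whereas dp_rank == 0 picks exactly one saver per (tp, pp) shard regardless of how ranks are laid out.

# Illustrative only -- this rank ordering and helper are assumptions, not code
# taken from llama.py. Assumed ordering: tp innermost, then dp, then pp.

def decompose(rank: int, tp_world: int, dp_world: int, pp_world: int):
    """Hypothetical helper: split a global rank into (tp, dp, pp) coordinates."""
    tp_rank = rank % tp_world
    dp_rank = (rank // tp_world) % dp_world
    pp_rank = rank // (tp_world * dp_world)
    return tp_rank, dp_rank, pp_rank

if __name__ == "__main__":
    tp_world, dp_world, pp_world = 2, 2, 2
    for rank in range(tp_world * dp_world * pp_world):
        tp, dp, pp = decompose(rank, tp_world, dp_world, pp_world)
        old_check = rank < tp_world * pp_world  # depends on rank layout
        new_check = dp == 0                     # one saver per (tp, pp) shard
        print(f"rank={rank}: tp={tp} dp={dp} pp={pp} old={old_check} new={new_check}")

Under this assumed ordering, the old check selects ranks 0-3, which includes two dp_rank 1 processes (duplicating the pipeline-stage-0 shards) and misses the dp_rank 0 processes of pipeline stage 1, so those shards would never be written. The new check selects ranks 0, 1, 4, 5: exactly one writer per (tp, pp) shard.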