NCCL_CUMEM_ENABLE fix. (#517)
* push changes

* push the change

* revert leaderboard changes

* push change

* fix
vwxyzjn authored Jan 14, 2025
1 parent ec76d40 commit 9582883
Showing 4 changed files with 36 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -1,2 +1,2 @@
[flake8]
extend-ignore = E203
extend-ignore = E203,E402
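
The new `E402` entry (module-level import not at top of file) is needed because the trainers below now run `os.environ["NCCL_CUMEM_ENABLE"] = "0"` before their imports, so every import that follows would otherwise trip flake8. A minimal sketch of the ordering requirement, assuming the variable must be visible before any library that initializes NCCL is imported:

import os

# Must be in the environment before any import that can initialize NCCL,
# since NCCL reads NCCL_CUMEM_ENABLE when communicators are created.
os.environ["NCCL_CUMEM_ENABLE"] = "0"  # NOQA

import torch      # E402: import after code -- now ignored repo-wide via .flake8
import deepspeed  # likewise E402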
2 changes: 1 addition & 1 deletion mason.py
@@ -33,7 +33,7 @@ def parse_beaker_dataset(dataset_str):
"ai2/neptune-cirrascale",
"ai2/allennlp-elara-cirrascale",
"ai2/ceres-cirrascale",

"ai2/ganymede-cirrascale",
]
GCP_CLUSTERS = [
"ai2/augusta-google-1"
34 changes: 21 additions & 13 deletions open_instruct/ppo_vllm_thread_ray.py
@@ -43,6 +43,8 @@
from queue import Empty, Queue
from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple

os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA

import deepspeed
import numpy as np
import pandas as pd
@@ -52,7 +54,6 @@
import torch.nn.functional as F
import torch.utils
import torch.utils.data
import vllm
from datasets import Dataset, DatasetDict
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from huggingface_hub import HfApi
@@ -744,11 +745,11 @@ def train(
world_size = vllm_num_engines * vllm_tensor_parallel_size + 1
backend = args.vllm_sync_backend
# https://github.com/OpenRLHF/OpenRLHF/issues/313
if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0":
backend = "gloo"
print(
"Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)"
)
# if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0":
# backend = "gloo"
# print(
# "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)"
# )
refs = [
engine.init_process_group.remote(
master_address,
@@ -967,10 +968,10 @@ def vllm_generate(

start_time = time.time()
broadcast_to_vllm()
print(
f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds"
)
if accelerator.is_main_process:
print(
f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds"
)
param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id)))
else:
if training_step != 1:
@@ -986,10 +987,10 @@
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
start_time = time.time()
broadcast_to_vllm()
print(
f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds"
)
if accelerator.is_main_process:
print(
f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds"
)
param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id)))
queries = queries_next

@@ -1302,7 +1303,12 @@ def vllm_generate(

# Ai2 logic: we use /output to store the artifacts of the job, so we
# make a copy of the model to `/output` in the end.
if args.try_auto_save_to_beaker and self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0 and args.output_dir != "/output":
if (
args.try_auto_save_to_beaker
and self.rank == 0
and len(self.beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir != "/output"
):
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
print("finished training")

@@ -1408,6 +1414,8 @@ def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] =
--run_oe_eval_experiments \
--evaluate_on_weka \
--run_safety_evaluations \
--run_id {wandb_url} \
--step {training_step} \
--skip_oi_evals"""
if args.oe_eval_tasks is not None:
command += f" --oe_eval_tasks {','.join(args.oe_eval_tasks)}"
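
Two hunks above replace unconditional timing prints with an `accelerator.is_main_process` guard, so the weight-sync message is logged once instead of once per rank. A standalone sketch of the pattern, assuming the 🤗 Accelerate API (`broadcast_to_vllm` in the diff is this repo's weight-sync helper, elided here):

import time

from accelerate import Accelerator

accelerator = Accelerator()

start_time = time.time()
# ... collective work runs on every rank (in the diff: broadcast_to_vllm()) ...
if accelerator.is_main_process:  # only rank 0 prints, avoiding duplicated log lines
    print(f"Time to load weights: {time.time() - start_time:.2f} seconds")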
16 changes: 13 additions & 3 deletions open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -43,6 +43,8 @@
from queue import Empty, Queue
from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple

os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA

import deepspeed
import numpy as np
import pandas as pd
@@ -52,7 +54,6 @@
import torch.nn.functional as F
import torch.utils
import torch.utils.data
import vllm
from datasets import Dataset, DatasetDict
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from huggingface_hub import HfApi
@@ -714,7 +715,9 @@ def from_pretrained(
self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config)
self.reward_model.eval()

assert args.reward_model_multiplier or args.apply_verifiable_reward, "Either `reward_model_multiplier` must be non-zero or `apply_verifiable_reward` must be True."
assert (
args.reward_model_multiplier or args.apply_verifiable_reward
), "Either `reward_model_multiplier` must be non-zero or `apply_verifiable_reward` must be True."

def get_vocab_size(self):
return self.policy.config.vocab_size
@@ -1377,7 +1380,12 @@ def vllm_generate(

# Ai2 logic: we use /output to store the artifacts of the job, so we
# make a copy of the model to `/output` in the end.
if args.try_auto_save_to_beaker and self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0 and args.output_dir != "/output":
if (
args.try_auto_save_to_beaker
and self.rank == 0
and len(self.beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir != "/output"
):
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
print("finished training")

@@ -1483,6 +1491,8 @@ def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] =
--run_oe_eval_experiments \
--evaluate_on_weka \
--run_safety_evaluations \
--run_id {wandb_url} \
--step {training_step} \
--skip_oi_evals"""
if args.oe_eval_tasks is not None:
command += f" --oe_eval_tasks {','.join(args.oe_eval_tasks)}"
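
Both trainers also reflow the single-line Beaker auto-save guard into a parenthesized multi-condition `if` before copying artifacts to `/output`. A minimal runnable sketch of that copy step; the stand-in variables below mirror the `args` and `self` attributes from the diff and are hypothetical:

import shutil

try_auto_save_to_beaker = True                # stand-in for args.try_auto_save_to_beaker
rank = 0                                      # stand-in for self.rank
beaker_dataset_id_urls = ["https://example"]  # stand-in for self.beaker_config.beaker_dataset_id_urls
output_dir = "checkpoints/final"              # stand-in for args.output_dir

if (
    try_auto_save_to_beaker
    and rank == 0                             # only one process performs the copy
    and len(beaker_dataset_id_urls) > 0       # only when the job has a Beaker dataset
    and output_dir != "/output"               # never copy /output onto itself
):
    # dirs_exist_ok=True merges into an existing /output tree instead of failing
    shutil.copytree(output_dir, "/output", dirs_exist_ok=True)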
