Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Sagemaker integration #151

Merged
merged 12 commits into from
Dec 20, 2023
6 changes: 3 additions & 3 deletions open_lm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from .data import proc_token
from .model import Block
from .losses import CrossEntropyLossWithZLoss
from open_lm.data import proc_token
from open_lm.model import Block
from open_lm.losses import CrossEntropyLossWithZLoss

try:
import wandb
Expand Down
Empty file added sagemaker_train/.dockerignore
Empty file.
34 changes: 34 additions & 0 deletions sagemaker_train/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
ARG AWS_REGION

# SageMaker PyTorch image
FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-sagemaker

# Run custom installation of libraries
# RUN pip install xxx
# RUN apt-get update && apt-get install -y xxx
# ENV <your environment variables>
# etc....

# Remove the conda installed symlink for libcurl, which causes an error with curl.
# Fixes the following error:
# curl: /opt/conda/lib/libcurl.so.4: no version information available (required by curl)
RUN rm /opt/conda/lib/libcurl.so.4

ENV PATH="/opt/ml/code:${PATH}"

# this environment variable is used by the SageMaker PyTorch container to determine our user code directory.
ENV SAGEMAKER_SUBMIT_DIRECTORY /opt/ml/code

# /opt/ml and all subdirectories are utilized by SageMaker, use the /code subdirectory to store your user code.
COPY . /opt/ml/code/
# Removing setup.py prevents SageMaker from re-installing the package.
RUN rm /opt/ml/code/setup.py

RUN pip install -r /opt/ml/code/requirements.txt
RUN pip uninstall flash-attn -y
# BUG FIX: the requirement must be quoted — an unquoted `>=` is interpreted by
# the shell as output redirection, so `pip install flash-attn>=2.2` actually ran
# `pip install flash-attn` (any version) and created a stray file named `=2.2`.
RUN pip install "flash-attn>=2.2"
# # Prevent sagemaker from installing requirements again.
# RUN rm /opt/ml/code/setup.py
RUN rm /opt/ml/code/requirements.txt

# Defines a script entrypoint
ENV SAGEMAKER_PROGRAM open_lm/main.py
37 changes: 37 additions & 0 deletions sagemaker_train/cfg_sample.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Gradient accumulation frequency and AdamW betas.
accum-freq: 4
beta1: 0.9
beta2: 0.95
data-key: "json"
# NOTE(review): "" appears to be used as a truthy marker for boolean CLI flags
# when this config is flattened into SageMaker hyperparameters — confirm
# against the launcher's argument handling.
dataset-resampled: ""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you set this to True instead of empty string, the error you mentioned should go away when passing the config via path (and similarly for all other keys which have a "" value - set them to True instead)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that and it still gives the same error. It seems to read everything as a string.

Just to double-check: The way to pass a config is to just do train_args = {"config": args.cfg_path} instead of the yaml.safe_load(f), right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm yeah that should be all you need. And you rebuilt the docker container after that change right? Will try it out later today, maybe there's something wrong with the parsing logic.

# delete-previous-checkpoint: False
# Total 25B * 40 = 1T tokens
epochs: 40
# NOTE(review): keys with "" values look like boolean flags passed as empty
# strings (see dataset-resampled above) — verify how the launcher forwards them.
fsdp: ""
fsdp-limit-all-gathers: ""
# grad-checkpointing: False
grad-clip-norm: 1
log-every-n-steps: 20
model: "open_lm_7b"
name: "sample_7b"
precision: "amp_bfloat16"
report-to: "wandb"
seed: 124
# Space-separated weights for the two training data sources.
train-data-mix-weights: 0.725 0.275
# TODO: fill in the training data path(s) before launching.
train-data: TODO
train-num-samples: 25_000_000_000
wandb-project-name: "lm1"
workers: 4

# Some important parameters, double checked with Mitchell:
batch-size: 16
ffn-type: swiglu
# fsdp-amp: False
fsdp-pure-bf16: ""
fsdp-backward-prefetch: ""
lr: 3e-4
lr-cooldown-end: 3e-5
model-norm: "gain_only_lp_layer_norm"
qk-norm: ""
warmup: 5000
wd: 0.1
z-loss-coefficient: 1e-4
187 changes: 187 additions & 0 deletions sagemaker_train/launch_sagemaker_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import argparse
import time
import os
import subprocess
import yaml
from datetime import datetime

import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch
from sagemaker.inputs import TrainingInput
from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper
from sagemaker_train.sm_utils import get_arn, get_remote_sync


# Base name used for the ECR repository and SageMaker job names.
NAME = "openlm"
# Map short CLI instance aliases to full SageMaker instance type names.
INSTANCE_MAPPER = {
    "p4": "ml.p4d.24xlarge",
    "p4de": "ml.p4de.24xlarge",
    "p5": "ml.p5.48xlarge",
}


def run_command(command):
    """Execute *command* in a shell, raising CalledProcessError on a non-zero exit."""
    subprocess.run(
        command,
        shell=True,
        check=True,
    )


def get_image(user, region, instance_type, build_image=False, update_image=False, profile="poweruser"):
    """Return the ECR URI of the user's training image, optionally (re)building it.

    With neither ``build_image`` nor ``update_image`` set, only the URI is
    computed (no docker commands run). Otherwise the image is built or
    updated, tagged, and pushed to ECR.

    Args:
        user: User name; combined with NAME to form the ECR repository name.
        region: AWS region of the ECR registry.
        instance_type: Unused here; kept for interface compatibility.
        build_image: Build the image from scratch (sagemaker_train/Dockerfile).
        update_image: Layer updated code onto the existing image
            (sagemaker_train/update.dockerfile).
        profile: AWS CLI profile used for STS/ECR calls.

    Returns:
        Fully-qualified ECR image URI tagged ``latest``.
    """
    os.environ["AWS_PROFILE"] = f"{profile}"
    account = subprocess.getoutput(
        f"aws --region {region} --profile {profile} sts get-caller-identity --query Account --output text"
    )
    algorithm_name = f"{user}-{NAME}"
    fullname = f"{account}.dkr.ecr.{region}.amazonaws.com/{algorithm_name}:latest"
    if not build_image and not update_image:
        return fullname

    login_cmd = f"aws ecr get-login-password --region {region} --profile {profile} | docker login --username AWS --password-stdin"

    if build_image:
        print("Building container")
        commands = [
            f"{login_cmd} 763104351884.dkr.ecr.{region}.amazonaws.com",
            f"docker build -f sagemaker_train/Dockerfile --build-arg AWS_REGION={region} --build-arg DOCKER_IGNORE_FILE=sagemaker_train/.dockerignore -t {algorithm_name} .",
            f"docker tag {algorithm_name} {fullname}",
            f"{login_cmd} {fullname}",
            f"aws --region {region} ecr describe-repositories --repository-names {algorithm_name} || aws --region {region} ecr create-repository --repository-name {algorithm_name}",
        ]
    elif update_image:
        print("Updating container")
        commands = [
            f"docker build -f sagemaker_train/update.dockerfile --build-arg DOCKER_IGNORE_FILE=sagemaker_train/.dockerignore --build-arg BASE_DOCKER={algorithm_name} -t {algorithm_name} .",
            f"docker tag {algorithm_name} {fullname}",
            f"{login_cmd} {fullname}",
        ]

    print("\n".join(commands))
    # Run each command via run_command (check=True): previously all commands
    # were joined into one shell invocation without check, so a failed build or
    # tag was silently ignored and a stale image could still be pushed.
    for command in commands:
        run_command(command)
    run_command(f"docker push {fullname}")
    print("Sleeping for 5 seconds to ensure push succeeded")
    time.sleep(5)

    # Identical to the string assembled above; don't rebuild it.
    return fullname


def main():
    """Parse command-line arguments and launch a SageMaker training run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--build", action="store_true", help="Build image from scratch")
    parser.add_argument("--update", action="store_true", help="Update code in image, don't re-build")
    parser.add_argument("--local", action="store_true")
    parser.add_argument("--user", required=True, help="User name")
    parser.add_argument("--cfg-path", required=True, help="Launch config")

    # AWS profile args
    parser.add_argument("--region", default="us-east-1", help="AWS region")
    parser.add_argument("--profile", default="poweruser", help="AWS profile to use")
    parser.add_argument("--arn", default=None, help="If None, reads from SAGEMAKER_ARN env var")
    parser.add_argument(
        "--s3-remote-sync", default=None, help="S3 path to sync to. If none, reads from S3_REMOTE_SYNC env var"
    )

    # Instance args
    parser.add_argument("--instance-count", default=1, type=int, help="Number of instances")
    parser.add_argument("--instance-type", default="p4de", choices=list(INSTANCE_MAPPER.keys()))
    parser.add_argument("--spot-instance", action="store_true")
    args = parser.parse_args()

    # The setup.py rename (to stop SageMaker re-installing the package) is
    # currently disabled; the try/except is kept so re-enabling it restores
    # the file even if the launch fails.
    setup_tmp_name = "./setup_renamed_for_sagemaker.py"
    # print(f"Renaming ./setup.py to {setup_tmp_name}")
    # os.rename("./setup.py", setup_tmp_name)
    try:
        main_after_setup_move(args)
    except BaseException:
        # Explicit BaseException instead of a bare `except:` — identical
        # semantics (also catches KeyboardInterrupt) and always re-raises.
        # os.rename(setup_tmp_name, "./setup.py")
        raise


def main_after_setup_move(args):
    """Build/fetch the training image and launch the SageMaker PyTorch job.

    Args:
        args: Parsed CLI namespace from main() (user, region, instance options,
            cfg-path, etc.).
    """
    image = get_image(
        args.user,
        args.region,
        args.instance_type,
        build_image=args.build,
        update_image=args.update,
        profile=args.profile,
    )

    ##########
    # Create session and make sure of account and region
    ##########
    sagemaker_session = sagemaker.Session(boto_session=boto3.session.Session(region_name=args.region))

    # provide a pre-existing role ARN as an alternative to creating a new role
    role = get_arn(args.arn)
    # BUG FIX: the original `role.split(["/"][-1])` only worked by accident
    # (`["/"][-1]` evaluates to "/") and produced a list; take the last path
    # component of the ARN directly.
    role_name = role.split("/")[-1]
    print(f"SageMaker Execution Role:{role}")
    print(f"The name of the Execution role: {role_name}")

    client = boto3.client("sts")
    account = client.get_caller_identity()["Account"]
    print(f"AWS account:{account}")

    session = boto3.session.Session()
    region = session.region_name
    print(f"AWS region:{region}")

    ##########
    # Configure the training
    ##########
    base_job_name = f"{args.user.replace('.', '-')}-{NAME}"

    checkpoint_local_path = "/opt/ml/checkpoints"

    # NOTE(review): the YAML is flattened into SageMaker hyperparameters here
    # rather than passed via `--config`, because SageMaker coerces all
    # hyperparameter values to strings, which breaks boolean flags — confirm
    # before switching to `--config`.
    with open(args.cfg_path, "r") as f:
        train_args = yaml.safe_load(f)

    train_args["logs"] = checkpoint_local_path if not args.local else "./logs/debug"

    def get_job_name(base, train_args):
        """Return a unique job name: `<base>_<timestamp-with-milliseconds>`."""
        now = datetime.now()
        # Format example: 2023-03-03-10-14-02-324
        now_ms_str = f"{now.microsecond // 1000:03d}"
        date_str = f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{now_ms_str}"

        job_name = "_".join([base, date_str])

        return job_name

    job_name = get_job_name(base_job_name, train_args)

    s3_remote_sync = get_remote_sync(args.s3_remote_sync)
    output_root = f"{s3_remote_sync}/sagemaker/{args.user}/{NAME}/"
    output_s3 = os.path.join(output_root, job_name)

    estimator = PyTorch(
        entry_point="open_lm/main.py",
        base_job_name=base_job_name,
        hyperparameters=train_args,
        role=role,
        image_uri=image,
        instance_count=int(args.instance_count),
        instance_type="local_gpu" if args.local else INSTANCE_MAPPER[args.instance_type],
        train_use_spot_instances=True if args.spot_instance else False,
        # sagemaker_session=sagemaker_session,
        output_path=output_s3,
        job_name=job_name,
        checkpoint_s3_uri=None if args.local else f"{output_s3}/checkpoint",
        checkpoint_local_path=None if args.local else checkpoint_local_path,
        code_location=output_s3,
        # Training using SMDataParallel Distributed Training Framework
        distribution={"torch_distributed": {"enabled": True}},
        # Max run 5 days (the previous comment said "10 days", contradicting
        # the value below).
        max_run=5 * 24 * 60 * 60,
        max_wait=5 * 24 * 60 * 60 if args.spot_instance else None,
        # max_run=60 * 60, # 60 minutes
        input_mode="FastFile",
        # environment={"TORCH_DISTRIBUTED_DEBUG": "DETAIL", "TORCH_CPP_LOG_LEVEL": "INFO"},
        keep_alive_period_in_seconds=30 * 60 if not args.spot_instance else None,  # 30 minutes
        dependencies=[SSHEstimatorWrapper.dependency_dir()],
    )

    estimator.fit()


if __name__ == "__main__":
main()
15 changes: 15 additions & 0 deletions sagemaker_train/sm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os


def get_arn(arn):
    """Return *arn* if provided, otherwise fall back to the SAGEMAKER_ARN env var.

    Raises:
        KeyError: if *arn* is None and SAGEMAKER_ARN is unset.
    """
    if arn is None:
        return os.environ["SAGEMAKER_ARN"]
    return arn


def get_remote_sync(s3_remote_sync):
    """Return *s3_remote_sync* if provided, otherwise the S3_REMOTE_SYNC env var.

    Raises:
        KeyError: if *s3_remote_sync* is None and S3_REMOTE_SYNC is unset.
    """
    if s3_remote_sync is None:
        return os.environ["S3_REMOTE_SYNC"]
    return s3_remote_sync
14 changes: 14 additions & 0 deletions sagemaker_train/update.dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
ARG BASE_DOCKER
# Dockerfile that updates the container with new code.
# SageMaker PyTorch image
# BASE_DOCKER is the previously-built image from sagemaker_train/Dockerfile.
FROM ${BASE_DOCKER}

# /opt/ml and all subdirectories are utilized by SageMaker, use the /code subdirectory to store your user code.
COPY . /opt/ml/code/

# RUN pip install -e /opt/ml/code/

# # Prevent sagemaker from installing requirements again.
RUN rm /opt/ml/code/requirements.txt

# Same entrypoint as the base image; re-declared because COPY may overwrite code.
ENV SAGEMAKER_PROGRAM open_lm/main.py
Loading