[DRAFT] Sagemaker integration #151
Conversation
Edited the file to remove TRI mentions. Now the region, ARN, and S3 path can be supplied through the command line or an environment variable. I'm still not sure how to handle the train-data parameter; that might be a separate issue altogether, and we could just ignore it for this PR, but I rather like being able to easily supply something simple like "openlm_mix_tri_s3" to the train-data parameter. Other than that, everything else here should work without issue.
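The CLI-or-env fallback described above could be sketched like this. This is an illustrative sketch only: the flag names and environment-variable names here are assumptions, not necessarily what launch_sagemaker_train.py actually uses.

```python
import argparse
import os


def parse_launch_args(argv=None):
    # Each value can come from a command-line flag or, failing that, an
    # environment variable (names here are hypothetical, not from the PR).
    parser = argparse.ArgumentParser()
    parser.add_argument("--region", default=os.environ.get("AWS_REGION"))
    parser.add_argument("--role-arn", default=os.environ.get("SAGEMAKER_ROLE_ARN"))
    parser.add_argument("--s3-path", default=os.environ.get("S3_OUTPUT_PATH"))
    args = parser.parse_args(argv)

    # Fail early if a value was supplied neither way.
    missing = [name for name, value in vars(args).items() if value is None]
    if missing:
        parser.error(f"missing values (set via flag or env var): {missing}")
    return args
```

A flag, when given, takes precedence over the environment variable, since the env value only serves as the argparse default.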
small nits, otherwise looks good to me!
checkpoint_local_path = "/opt/ml/checkpoints"

with open(args.cfg_path, "r") as f:
    train_args = yaml.safe_load(f)
Now that openlm supports a --config, let's just pass the config to openlm directly, with --config args.cfg_path?
Tried this, but it seems to give some typing errors when passed through SageMaker:
Type mismatch (config: <class 'str'> vs. argparse: <class 'bool'>) with values (config: vs. argparse: False) for config. key: dataset_resampled
Leaving it as is for now
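The error above is consistent with how SageMaker handles hyperparameters: they are serialized to strings before reaching the training job, so a value that argparse expects as a bool arrives as `""` or `"False"`. A minimal sketch of the mismatch (the `--dataset-resampled` flag matches the config key discussed here; the rest is illustrative):

```python
import argparse

# argparse defines the flag as a store_true bool with default False.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-resampled", action="store_true")

# SageMaker serializes hyperparameters to strings, so the value that
# comes through the hyperparameter channel is a str, not a bool.
sagemaker_value = ""
argparse_default = parser.get_default("dataset_resampled")

print(type(sagemaker_value))   # <class 'str'>
print(type(argparse_default))  # <class 'bool'>
# This is the "config: <class 'str'> vs. argparse: <class 'bool'>"
# mismatch reported in the error message above.
```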
sagemaker_train/cfg_sample.yaml
Outdated
beta1: 0.9
beta2: 0.95
data-key: "json"
dataset-resampled: ""
If you set this to True instead of empty string, the error you mentioned should go away when passing the config via path (and similarly for all other keys which have a "" value - set them to True instead)
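To see why the suggested change matters: YAML parses an empty-string value as a `str` but `true`/`True` as a `bool`, which is what the `store_true` argparse default expects. A quick check (assuming PyYAML, which this code already uses via `yaml.safe_load`):

```python
import yaml

# With an empty-string value, YAML yields a str; with `true`, a bool.
as_string = yaml.safe_load('dataset-resampled: ""')
as_bool = yaml.safe_load("dataset-resampled: true")

print(type(as_string["dataset-resampled"]))  # <class 'str'>
print(type(as_bool["dataset-resampled"]))    # <class 'bool'>
```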
I tried that and it still gives the same error. It seems to read everything as a string.
Just to double-check: the way to pass a config is to do train_args = {"config": args.cfg_path} instead of the yaml.safe_load(f), right?
Hmm yeah that should be all you need. And you rebuilt the docker container after that change right? Will try it out later today, maybe there's something wrong with the parsing logic.
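If the parsing logic does turn out to be the culprit, one possible shape of a fix is to coerce each string-valued config entry to the type of the corresponding argparse default before comparing or overriding. This is only a sketch of the idea, not the actual fix in the linked PR, and the helper name is hypothetical:

```python
def coerce_to_default_type(value, default):
    """Coerce a string config value to the type of the argparse default,
    so "" / "False" / "True" become proper bools and "3" becomes an int."""
    if isinstance(default, bool) and isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes")
    if default is not None and isinstance(value, str):
        return type(default)(value)
    return value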
Made some changes that should fix the issues with config here: https://github.com/sedrick-keh-tri/open_lm/pull/1. I think if you merge that PR into your branch, it should update this PR, and we should be good to merge.
This still needs some fixing because there is some TRI-specific stuff in launch_sagemaker_train.py.
It works if you run:
python sagemaker_train/launch_sagemaker_train.py --user sedrick.keh --cfg-path sagemaker_train/cfg_sample.yaml --build
For easy testing, I edited params.py to accept openlm_mix_tri_s3 in my local code (not pushed), but aside from that it works without any other changes.