Skip to content

Commit

Permalink
adding support for local session (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorVasiljevic-TRI authored Dec 21, 2023
1 parent a79aa35 commit 685bc51
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions sagemaker_train/launch_sagemaker_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def run_command(command):
subprocess.run(command, shell=True, check=True)


def get_image(user, instance_type, build_type=None, profile="poweruser", region="us-east-1"):
def get_image(user, instance_type, build_type=None, profile="default", region="us-east-1"):
os.environ["AWS_PROFILE"] = f"{profile}"
account = subprocess.getoutput(
f"aws --region {region} --profile {profile} sts get-caller-identity --query Account --output text"
Expand Down Expand Up @@ -87,7 +87,7 @@ def main():

# AWS profile args
parser.add_argument("--region", default="us-east-1", help="AWS region")
parser.add_argument("--profile", default="poweruser", help="AWS profile to use")
parser.add_argument("--profile", default="default", help="AWS profile to use")
parser.add_argument("--arn", default=None, help="If None, reads from SAGEMAKER_ARN env var")
parser.add_argument(
"--s3-remote-sync", default=None, help="S3 path to sync to. If none, reads from S3_REMOTE_SYNC env var"
Expand Down Expand Up @@ -126,6 +126,11 @@ def main_after_setup_move(args):
##########
sagemaker_session = sagemaker.Session(boto_session=boto3.session.Session(region_name=args.region))

if args.local:
from sagemaker.local import LocalSession

sagemaker_session = LocalSession()

role = args.arn
# provide a pre-existing role ARN as an alternative to creating a new role
role_name = role.split(["/"][-1])
Expand Down

0 comments on commit 685bc51

Please sign in to comment.