@@ -138,7 +138,6 @@ def _set_base_path_env(): # type: () -> None
138138 str: the path to the intermediate output directory, e.g. /opt/ml/output/intermediate.
139139"""
140140
141-
142141HYPERPARAMETERS_FILE = 'hyperparameters.json' # type: str
143142RESOURCE_CONFIG_FILE = 'resourceconfig.json' # type: str
144143INPUT_DATA_CONFIG_FILE = 'inputdataconfig.json' # type: str
@@ -164,7 +163,7 @@ def _create_training_directories():
164163
165164 resources_dict = {
166165 "current_host" : host_name ,
167- "hosts" : [host_name ]
166+ "hosts" : [host_name ]
168167 }
169168 _write_json (resources_dict , resource_config_file_dir )
170169
@@ -534,11 +533,11 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
534533 resource_config = resource_config or read_resource_config ()
535534 current_host = resource_config ['current_host' ]
536535 hosts = resource_config ['hosts' ]
537- network_interface_name = resource_config .get ('network_interface_name' , 'ethwe' )
538536 input_data_config = input_data_config or read_input_data_config ()
539537
540538 all_hyperparameters = hyperparameters or read_hyperparameters ()
541- split_result = _mapping .split_by_criteria (all_hyperparameters , keys = _params .SAGEMAKER_HYPERPARAMETERS ,
539+ split_result = _mapping .split_by_criteria (all_hyperparameters ,
540+ keys = _params .SAGEMAKER_HYPERPARAMETERS ,
542541 prefix = _params .SAGEMAKER_PREFIX )
543542
544543 sagemaker_hyperparameters = split_result .included
@@ -547,14 +546,17 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
547546 if k not in _params .SAGEMAKER_HYPERPARAMETERS
548547 }
549548
550- sagemaker_region = sagemaker_hyperparameters .get (_params .REGION_NAME_PARAM , boto3 .session .Session ().region_name )
549+ sagemaker_region = sagemaker_hyperparameters .get (_params .REGION_NAME_PARAM ,
550+ boto3 .session .Session ().region_name )
551551
552552 os .environ [_params .JOB_NAME_ENV ] = sagemaker_hyperparameters .get (_params .JOB_NAME_PARAM , '' )
553553 os .environ [_params .CURRENT_HOST_ENV ] = current_host
554554 os .environ [_params .REGION_NAME_ENV ] = sagemaker_region or ''
555555
556556 self ._hosts = hosts
557- self ._network_interface_name = network_interface_name
557+
558+ self ._network_interface_name = resource_config .get ('network_interface_name' , 'eth0' )
559+
558560 self ._hyperparameters = split_result .excluded
559561 self ._additional_framework_parameters = additional_framework_parameters
560562 self ._resource_config = resource_config
@@ -567,7 +569,8 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
567569 # override base class attributes
568570 if self ._module_name is None :
569571 self ._module_name = str (sagemaker_hyperparameters .get (_params .USER_PROGRAM_PARAM , None ))
570- self ._user_entry_point = self ._user_entry_point or sagemaker_hyperparameters .get (_params .USER_PROGRAM_PARAM )
572+ self ._user_entry_point = self ._user_entry_point or sagemaker_hyperparameters .get (
573+ _params .USER_PROGRAM_PARAM )
571574
572575 self ._module_dir = str (sagemaker_hyperparameters .get (_params .SUBMIT_DIR_PARAM , code_dir ))
573576 self ._log_level = sagemaker_hyperparameters .get (_params .LOG_LEVEL_PARAM , logging .INFO )
@@ -580,6 +583,21 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
580583 self ._output_dir = output_dir
581584 self ._job_name = os .environ .get (_params .TRAINING_JOB_ENV .upper (), None )
582585
586+ self ._master_hostname = list (hosts )[0 ]
587+ self ._is_master = current_host == self ._master_hostname
588+
589+ @property
590+ def is_master (self ): # type: () -> bool
591+ """Returns True if the current host is the master host, i.e. the first entry in the resource config's ``hosts``.
592+ """
593+ return self ._is_master
594+
595+ @property
596+ def master_hostname (self ): # type: () -> str
597+ """Returns the hostname of the master node, taken as the first entry in the resource config's ``hosts``.
598+ """
599+ return self ._master_hostname
600+
583601 @property
584602 def job_name (self ): # type: () -> str
585603 """The name of the current training job.
@@ -625,16 +643,19 @@ def to_env_vars(self):
625643 """
626644
627645 env = {
628- 'hosts' : self .hosts , 'network_interface_name' : self .network_interface_name ,
629- 'hps' : self .hyperparameters , 'user_entry_point' : self .user_entry_point ,
646+ 'hosts' : self .hosts , 'network_interface_name' : self .network_interface_name ,
647+ 'hps' : self .hyperparameters , 'user_entry_point' : self .user_entry_point ,
630648 'framework_params' : self .additional_framework_parameters ,
631- 'resource_config' : self .resource_config , 'input_data_config' : self .input_data_config ,
632- 'output_data_dir' : self .output_data_dir , 'channels' : sorted (self .channel_input_dirs .keys ()),
633- 'current_host' : self .current_host , 'module_name' : self .module_name , 'log_level' : self .log_level ,
649+ 'resource_config' : self .resource_config , 'input_data_config' : self .input_data_config ,
650+ 'output_data_dir' : self .output_data_dir ,
651+ 'channels' : sorted (self .channel_input_dirs .keys ()),
652+ 'current_host' : self .current_host , 'module_name' : self .module_name ,
653+ 'log_level' : self .log_level ,
634654 'framework_module' : self .framework_module , 'input_dir' : self .input_dir ,
635- 'input_config_dir' : self .input_config_dir , 'output_dir' : self .output_dir , 'num_cpus' : self .num_cpus ,
636- 'num_gpus' : self .num_gpus , 'model_dir' : self .model_dir , 'module_dir' : self .module_dir ,
637- 'training_env' : dict (self ), 'user_args' : self .to_cmd_args (),
655+ 'input_config_dir' : self .input_config_dir , 'output_dir' : self .output_dir ,
656+ 'num_cpus' : self .num_cpus ,
657+ 'num_gpus' : self .num_gpus , 'model_dir' : self .model_dir , 'module_dir' : self .module_dir ,
658+ 'training_env' : dict (self ), 'user_args' : self .to_cmd_args (),
638659 'output_intermediate_dir' : self .output_intermediate_dir
639660 }
640661
0 commit comments