Skip to content
This repository was archived by the owner on Aug 26, 2020. It is now read-only.

Commit 74e0019

Browse files
authored
Generic OpenMPI support (#157)
* Generic MPI - Create library to changehostname when mpi starts (#153)
1 parent faa3b3a commit 74e0019

32 files changed

+1245
-161
lines changed

.flake8

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[flake8]
2-
application_import_names = sagemaker_containers, test
3-
import-order-style = google
2+
application_import_names = sagemaker_containers, test, libchangehostname
3+
import-order-style = google

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,4 +485,4 @@ SM_TRAINING_ENV='
485485
Provides the entire training information as a JSON encoded dictionary.
486486
## License
487487

488-
This library is licensed under the Apache 2.0 License.
488+
This library is licensed under the Apache 2.0 License.

setup.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,31 @@
1616
import os
1717
import sys
1818

19-
from setuptools import find_packages, setup
19+
import setuptools
2020

2121

2222
def read(file_name):
2323
return open(os.path.join(os.path.dirname(__file__), file_name)).read()
2424

2525

26-
packages = find_packages(where='src', exclude=('test',))
26+
packages = setuptools.find_packages(where='src', exclude=('test',))
2727
packages.append('sagemaker_containers.etc')
2828

2929
required_packages = [
30-
'boto3', 'six', 'pip', 'flask', 'gunicorn', 'gevent', 'inotify_simple', 'werkzeug'
30+
'numpy', 'boto3', 'six', 'pip', 'flask', 'gunicorn', 'typing',
31+
'gevent', 'inotify_simple', 'werkzeug', 'paramiko==2.4.2', 'psutil==5.4.8'
3132
]
3233

3334
# enum is introduced in Python 3.4. Installing enum back port
3435
if sys.version_info < (3, 4):
3536
required_packages.append('enum34 >= 1.1.6')
3637

37-
setup(
38+
gethostname = setuptools.Extension('libchangehostname',
39+
sources=['src/sagemaker_containers/c/libchangehostname.c'],
40+
extra_compile_args=['-Wall', '-shared', '-export-dynamic',
41+
'-ldl'])
42+
43+
setuptools.setup(
3844
name='sagemaker_containers',
3945
version='2.3.5',
4046
description='Open source library for creating containers to run on Amazon SageMaker.',
@@ -46,6 +52,7 @@ def read(file_name):
4652
},
4753
package_data={'sagemaker_containers.etc': ['*']},
4854
py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob('src/*.py')],
55+
ext_modules=[gethostname],
4956
long_description=read('README.md'),
5057
author='Amazon Web Services',
5158
url='https://github.com/aws/sagemaker-containers/',
@@ -64,11 +71,11 @@ def read(file_name):
6471
install_requires=required_packages,
6572

6673
extras_require={
67-
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'mock', 'sagemaker', 'numpy']
74+
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'mock', 'sagemaker==1.16.2']
6875
},
6976

7077
entry_points={
71-
'console_scripts': ['serve=sagemaker_containers.cli.serve:main',
72-
'train=sagemaker_containers.cli.train:main'],
78+
'console_scripts': ['serve=sagemaker_containers.cli.serve:main',
79+
'train=sagemaker_containers.cli.train:main'],
7380
}
7481
)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License'). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the 'license' file accompanying this file. This file is
10+
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import enum
14+
import os
15+
16+
17+
class _EntryPointType(enum.Enum):
18+
PYTHON_PACKAGE = 'PYTHON_PACKAGE'
19+
PYTHON_PROGRAM = 'PYTHON_PROGRAM'
20+
COMMAND = 'COMMAND'
21+
22+
23+
PYTHON_PACKAGE = _EntryPointType.PYTHON_PACKAGE
24+
PYTHON_PROGRAM = _EntryPointType.PYTHON_PROGRAM
25+
COMMAND = _EntryPointType.COMMAND
26+
27+
28+
def get(path, name): # type: (str, str) -> _EntryPointType
29+
"""
30+
Args:
31+
path (string): Directory where the entry point is located
32+
name (string): Name of the entry point file
33+
34+
Returns:
35+
(_EntryPointType): The type of the entry point
36+
"""
37+
if 'setup.py' in os.listdir(path):
38+
return _EntryPointType.PYTHON_PACKAGE
39+
elif name.endswith('.py'):
40+
return _EntryPointType.PYTHON_PROGRAM
41+
else:
42+
return _EntryPointType.COMMAND

src/sagemaker_containers/_env.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ def _set_base_path_env(): # type: () -> None
138138
str: the path to the intermediate output directory, e.g. /opt/ml/output/intermediate.
139139
"""
140140

141-
142141
HYPERPARAMETERS_FILE = 'hyperparameters.json' # type: str
143142
RESOURCE_CONFIG_FILE = 'resourceconfig.json' # type: str
144143
INPUT_DATA_CONFIG_FILE = 'inputdataconfig.json' # type: str
@@ -164,7 +163,7 @@ def _create_training_directories():
164163

165164
resources_dict = {
166165
"current_host": host_name,
167-
"hosts": [host_name]
166+
"hosts": [host_name]
168167
}
169168
_write_json(resources_dict, resource_config_file_dir)
170169

@@ -534,11 +533,11 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
534533
resource_config = resource_config or read_resource_config()
535534
current_host = resource_config['current_host']
536535
hosts = resource_config['hosts']
537-
network_interface_name = resource_config.get('network_interface_name', 'ethwe')
538536
input_data_config = input_data_config or read_input_data_config()
539537

540538
all_hyperparameters = hyperparameters or read_hyperparameters()
541-
split_result = _mapping.split_by_criteria(all_hyperparameters, keys=_params.SAGEMAKER_HYPERPARAMETERS,
539+
split_result = _mapping.split_by_criteria(all_hyperparameters,
540+
keys=_params.SAGEMAKER_HYPERPARAMETERS,
542541
prefix=_params.SAGEMAKER_PREFIX)
543542

544543
sagemaker_hyperparameters = split_result.included
@@ -547,14 +546,17 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
547546
if k not in _params.SAGEMAKER_HYPERPARAMETERS
548547
}
549548

550-
sagemaker_region = sagemaker_hyperparameters.get(_params.REGION_NAME_PARAM, boto3.session.Session().region_name)
549+
sagemaker_region = sagemaker_hyperparameters.get(_params.REGION_NAME_PARAM,
550+
boto3.session.Session().region_name)
551551

552552
os.environ[_params.JOB_NAME_ENV] = sagemaker_hyperparameters.get(_params.JOB_NAME_PARAM, '')
553553
os.environ[_params.CURRENT_HOST_ENV] = current_host
554554
os.environ[_params.REGION_NAME_ENV] = sagemaker_region or ''
555555

556556
self._hosts = hosts
557-
self._network_interface_name = network_interface_name
557+
558+
self._network_interface_name = resource_config.get('network_interface_name', 'eth0')
559+
558560
self._hyperparameters = split_result.excluded
559561
self._additional_framework_parameters = additional_framework_parameters
560562
self._resource_config = resource_config
@@ -567,7 +569,8 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
567569
# override base class attributes
568570
if self._module_name is None:
569571
self._module_name = str(sagemaker_hyperparameters.get(_params.USER_PROGRAM_PARAM, None))
570-
self._user_entry_point = self._user_entry_point or sagemaker_hyperparameters.get(_params.USER_PROGRAM_PARAM)
572+
self._user_entry_point = self._user_entry_point or sagemaker_hyperparameters.get(
573+
_params.USER_PROGRAM_PARAM)
571574

572575
self._module_dir = str(sagemaker_hyperparameters.get(_params.SUBMIT_DIR_PARAM, code_dir))
573576
self._log_level = sagemaker_hyperparameters.get(_params.LOG_LEVEL_PARAM, logging.INFO)
@@ -580,6 +583,21 @@ def __init__(self, resource_config=None, input_data_config=None, hyperparameters
580583
self._output_dir = output_dir
581584
self._job_name = os.environ.get(_params.TRAINING_JOB_ENV.upper(), None)
582585

586+
self._master_hostname = list(hosts)[0]
587+
self._is_master = current_host == self._master_hostname
588+
589+
@property
590+
def is_master(self): # type: () -> bool
591+
"""Returns True if host is master
592+
"""
593+
return self._is_master
594+
595+
@property
596+
def master_hostname(self): # type: () -> str
597+
"""Returns the hostname of the master node
598+
"""
599+
return self._master_hostname
600+
583601
@property
584602
def job_name(self): # type: () -> str
585603
"""The name of the current training job.
@@ -625,16 +643,19 @@ def to_env_vars(self):
625643
"""
626644

627645
env = {
628-
'hosts': self.hosts, 'network_interface_name': self.network_interface_name,
629-
'hps': self.hyperparameters, 'user_entry_point': self.user_entry_point,
646+
'hosts': self.hosts, 'network_interface_name': self.network_interface_name,
647+
'hps': self.hyperparameters, 'user_entry_point': self.user_entry_point,
630648
'framework_params': self.additional_framework_parameters,
631-
'resource_config': self.resource_config, 'input_data_config': self.input_data_config,
632-
'output_data_dir': self.output_data_dir, 'channels': sorted(self.channel_input_dirs.keys()),
633-
'current_host': self.current_host, 'module_name': self.module_name, 'log_level': self.log_level,
649+
'resource_config': self.resource_config, 'input_data_config': self.input_data_config,
650+
'output_data_dir': self.output_data_dir,
651+
'channels': sorted(self.channel_input_dirs.keys()),
652+
'current_host': self.current_host, 'module_name': self.module_name,
653+
'log_level': self.log_level,
634654
'framework_module': self.framework_module, 'input_dir': self.input_dir,
635-
'input_config_dir': self.input_config_dir, 'output_dir': self.output_dir, 'num_cpus': self.num_cpus,
636-
'num_gpus': self.num_gpus, 'model_dir': self.model_dir, 'module_dir': self.module_dir,
637-
'training_env': dict(self), 'user_args': self.to_cmd_args(),
655+
'input_config_dir': self.input_config_dir, 'output_dir': self.output_dir,
656+
'num_cpus': self.num_cpus,
657+
'num_gpus': self.num_gpus, 'model_dir': self.model_dir, 'module_dir': self.module_dir,
658+
'training_env': dict(self), 'user_args': self.to_cmd_args(),
638659
'output_intermediate_dir': self.output_intermediate_dir
639660
}
640661

0 commit comments

Comments
 (0)