Skip to content

Commit e3d5e36

Browse files
SSYernarfacebook-github-bot
authored andcommitted
JSON config support for cmd_conf utility (#3227)
Summary: Pull Request resolved: #3227 Added a support for JSON file configuration of cmd_conf decorator for benchmark parameters parsing. This feature makes easier to reproduce complex configurations without the need to pass CLI arguments. Example .json file should look like: ``` { "RunOptions": { "world_size": 1, "num_batches": 3 }, "PipelineConfig": { "pipeline": "sparse" } } ``` Also, configs can be listed in a 'flat' way as well: ``` { "world_size": 2, "num_batches": 3, "pipeline": "base" } ``` To run, add `--json_config` flag with the .json file path. JSON configs will override YAML configs if both are given. Reviewed By: aliafzal Differential Revision: D78833153 fbshipit-source-id: 9be6b0e8945497f4f4e9f5c3cf96bd79ff03ea9d
1 parent 2e874f7 commit e3d5e36

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,20 @@ def set_embedding_config(
536536
# pyre-ignore [24]
537537
def cmd_conf(func: Callable) -> Callable:
538538

539+
def _load_config_file(config_path: str, is_json: bool = False) -> Dict[str, Any]:
540+
if not config_path:
541+
return {}
542+
543+
try:
544+
with open(config_path, "r") as f:
545+
if is_json:
546+
return json.load(f) or {}
547+
else:
548+
return yaml.safe_load(f) or {}
549+
except Exception as e:
550+
logger.warning(f"Failed to load config because {e}. Proceeding without it.")
551+
return {}
552+
539553
# pyre-ignore [3]
540554
def wrapper() -> Any:
541555
sig = inspect.signature(func)
@@ -548,6 +562,13 @@ def wrapper() -> Any:
548562
help="YAML config file for benchmarking",
549563
)
550564

565+
parser.add_argument(
566+
"--json_config",
567+
type=str,
568+
default=None,
569+
help="JSON config file for benchmarking",
570+
)
571+
551572
# Add loglevel argument with current logger level as default
552573
parser.add_argument(
553574
"--loglevel",
@@ -558,18 +579,18 @@ def wrapper() -> Any:
558579

559580
pre_args, _ = parser.parse_known_args()
560581

561-
yaml_defaults: Dict[str, Any] = {}
562-
if pre_args.yaml_config:
563-
try:
564-
with open(pre_args.yaml_config, "r") as f:
565-
yaml_defaults = yaml.safe_load(f) or {}
566-
logger.info(
567-
f"Loaded YAML config from {pre_args.yaml_config}: {yaml_defaults}"
568-
)
569-
except Exception as e:
570-
logger.warning(
571-
f"Failed to load YAML config because {e}. Proceeding without it."
572-
)
582+
yaml_defaults: Dict[str, Any] = (
583+
_load_config_file(pre_args.yaml_config, is_json=False)
584+
if pre_args.yaml_config
585+
else {}
586+
)
587+
json_defaults: Dict[str, Any] = (
588+
_load_config_file(pre_args.json_config, is_json=True)
589+
if pre_args.json_config
590+
else {}
591+
)
592+
# Merge the two dictionaries, JSON overrides YAML
593+
merged_defaults = {**yaml_defaults, **json_defaults}
573594

574595
seen_args = set() # track all --<name> we've added
575596

@@ -595,10 +616,10 @@ def wrapper() -> Any:
595616
ftype = non_none[0]
596617
origin = get_origin(ftype)
597618

598-
# Handle default_factory value and allow YAML config to override it
599-
default_value = yaml_defaults.get(
619+
# Handle default_factory value and allow config to override
620+
default_value = merged_defaults.get(
600621
arg_name, # flat lookup
601-
yaml_defaults.get(cls.__name__, {}).get( # hierarchy lookup
622+
merged_defaults.get(cls.__name__, {}).get( # hierarchy lookup
602623
arg_name,
603624
(
604625
f.default_factory() # pyre-ignore [29]

0 commit comments

Comments
 (0)