
Commit ac8bc67

TroyGarden authored and meta-codesync[bot] committed
correctly set log level for benchmark runs (meta-pytorch#3494)
Summary:
Pull Request resolved: meta-pytorch#3494

# context
* loglevel is not correctly set in the train pipeline benchmark due to the multiprocess setup
* the log level is only set in the main process, not in the forked/spawned processes
* this diff adds a `loglevel` argument to the RunConfig so that every runner function can call `set_log_level`
* it also lets errors from the yaml or json parser propagate directly; previously the failure was only logged as a warning, and that warning was buried in lengthy logs
* with loglevel=info we can now see the planner info: P2014482201

Reviewed By: spmex

Differential Revision: D85829837

fbshipit-source-id: 9719baf4307972a1794bf8870cd5c2df8add4436
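Why the per-process call is needed (a minimal standalone sketch, not code from this diff): logging configuration applied in the main process is not inherited by workers started with the spawn method, so each worker has to re-apply the requested level itself. All names below are illustrative only.

import logging
import multiprocessing as mp


def worker(loglevel: str) -> None:
    # A freshly spawned interpreter starts with the default WARNING level,
    # regardless of what the parent configured, so set it again here.
    logging.basicConfig()
    logging.root.setLevel(logging._nameToLevel[loglevel.upper()])
    logging.getLogger("planner").info("visible only because the worker set its own level")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)  # affects the parent process only
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=worker, args=("info",))
    p.start()
    p.join()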
1 parent 75dfb7f · commit ac8bc67

2 files changed: 23 additions & 14 deletions


torchrec/distributed/benchmark/base.py

Lines changed: 22 additions & 13 deletions
@@ -425,17 +425,11 @@ def _load_config_file(
         if not config_path:
             return {}
 
-        try:
-            with open(config_path, "r") as f:
-                if is_json:
-                    return json.load(f) or {}
-                else:
-                    return yaml.safe_load(f) or {}
-        except Exception as e:
-            logger.error(
-                f"Failed to load config because {e}. Proceeding without it."
-            )
-            return {}
+        with open(config_path, "r") as f:
+            if is_json:
+                return json.load(f) or {}
+            else:
+                return yaml.safe_load(f) or {}
 
     @functools.wraps(func)
     def wrapper() -> Any:  # pyre-ignore [3]
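For context on the hunk above, a simplified stand-in (not the exact torchrec helper) showing the resulting behavior: with the try/except gone, a malformed config file raises the parser error at the call site instead of logging a warning and silently falling back to an empty dict.

import json

import yaml


def load_config_file(config_path: str, is_json: bool) -> dict:
    # Simplified stand-in for _load_config_file with the error handling removed.
    if not config_path:
        return {}
    with open(config_path, "r") as f:
        if is_json:
            return json.load(f) or {}
        return yaml.safe_load(f) or {}


# A broken YAML file now surfaces as yaml.YAMLError (or json.JSONDecodeError for
# JSON) instead of a warning buried in the benchmark logs.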
@@ -479,7 +473,12 @@ def wrapper() -> Any:  # pyre-ignore [3]
         # Merge the two dictionaries, JSON overrides YAML
         merged_defaults = {**yaml_defaults, **json_defaults}
 
-        seen_args = set()  # track all --<name> we've added
+        # track all --<name> we've added
+        seen_args = {
+            "json_config",
+            "yaml_config",
+            "loglevel",
+        }
 
         for _name, param in sig.parameters.items():
             cls = param.annotation
@@ -548,7 +547,12 @@ def wrapper() -> Any:  # pyre-ignore [3]
         logger.info(config_instance)
 
         loglevel = logging._nameToLevel[args.loglevel.upper()]
-        logger.setLevel(loglevel)
+        # Set loglevel for all existing loggers
+        for existing_logger_name in logging.root.manager.loggerDict:
+            existing_logger = logging.getLogger(existing_logger_name)
+            existing_logger.setLevel(loglevel)
+        # Also set the root logger
+        logging.root.setLevel(loglevel)
 
         return func(**kwargs)
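A self-contained sketch of the pattern in the hunk above, assuming only the standard library: logging.root.manager.loggerDict lists every logger created so far, so iterating it pushes the requested level onto all of them, while setting the root level covers loggers created later.

import logging


def set_all_logger_levels(loglevel: str) -> None:
    level = logging._nameToLevel[loglevel.upper()]
    # loggerDict holds every logger name registered so far.
    for existing_name in logging.root.manager.loggerDict:
        logging.getLogger(existing_name).setLevel(level)
    # Loggers created after this call inherit the root level instead.
    logging.root.setLevel(level)


logging.getLogger("torchrec.planner")  # pretend a library module created this logger
set_all_logger_levels("info")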

@@ -857,6 +861,7 @@ class BenchFuncConfig:
     export_stacks: bool = False
     all_rank_traces: bool = False
     memory_snapshot: bool = False
+    loglevel: str = "WARNING"
 
     # pyre-ignore [2]
     def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
@@ -873,6 +878,10 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
             "memory_snapshot": self.memory_snapshot,
         } | kwargs_to_override
 
+    def set_log_level(self) -> None:
+        loglevel = logging._nameToLevel[self.loglevel.upper()]
+        logging.root.setLevel(loglevel)
+
 
 def benchmark_func(
     name: str,

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def runner(
         torch.cuda.is_available() and torch.cuda.device_count() >= world_size
     ), "CUDA not available or insufficient GPUs for the requested world_size"
 
-    torch.autograd.set_detect_anomaly(True)
+    run_option.set_log_level()
     with MultiProcessContext(
         rank=rank,
         world_size=world_size,
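As a usage note (a hedged sketch with a stripped-down stand-in for the real run options class): each per-rank runner applies the configured level as its first step, so logging inside the spawned process honors the requested verbosity.

import logging
from dataclasses import dataclass


@dataclass
class Options:  # illustrative stand-in for the benchmark run options
    loglevel: str = "WARNING"

    def set_log_level(self) -> None:
        logging.root.setLevel(logging._nameToLevel[self.loglevel.upper()])


def runner(rank: int, world_size: int, run_option: Options) -> None:
    # First thing inside the spawned process: apply the requested level locally.
    run_option.set_log_level()
    logging.getLogger(__name__).info("rank %d of %d starting", rank, world_size)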
