Skip to content

Commit 20f01a5

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
recmetrics for fault tolerance (pytorch#2981)
Summary: Pull Request resolved: pytorch#2981 in this diff we expose two new methods for metric module: get_pre_compute_states and load_pre_compute_states. like the name suggests this function returns the metric states before the metric.compute() call. we aggregate the states and apply the reduction function to those intermediate states and return the result. will add fused task aggregation in a later diff. this diff supports unfused task computation only. usage: pass in the devicemesh from DMPCollection or can pass in DMPCollection.sharding_pg example output is setup as: ``` ne: { task_name: {state : tensor, state : tensor}, task_name : {state : tensor, state : tensor} }, mse: ... ``` real output: ``` 'ne': {'DefaultTask': {'cross_entropy_sum': tensor([18453.3014], device='cuda:1', dtype=torch.float64), 'window_cross_entropy_sum': tensor([18453.3014], device='cuda:1', dtype=torch.float64), 'weighted_num_samples': tensor([12778.6176], device='cuda:1', dtype=torch.float64), 'window_weighted_num_samples': tensor([12778.6176], device='cuda:1', dtype=torch.float64), 'pos_labels': tensor([6420.7206], device='cuda:1', dtype=torch.float64), 'window_pos_labels': tensor([6420.7206], device='cuda:1', dtype=torch.float64), 'neg_labels': tensor([6357.8971], device='cuda:1', dtype=torch.float64), 'window_neg_labels': tensor([6357.8971], device='cuda:1', dtype=torch.float64)}}} ``` Reviewed By: peterfu0 Differential Revision: D71934160 fbshipit-source-id: a1b79274dba9f46f2f61df4edd3a16c21b6a91bc
1 parent a2b1ee6 commit 20f01a5

File tree

3 files changed

+214
-1
lines changed

3 files changed

+214
-1
lines changed

torchrec/metrics/metric_module.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch
1818
import torch.distributed as dist
1919
import torch.nn as nn
20+
from torch.distributed.tensor import DeviceMesh
2021
from torch.profiler import record_function
2122
from torchrec.metrics.accuracy import AccuracyMetric
2223
from torchrec.metrics.auc import AUCMetric
@@ -354,6 +355,125 @@ def reset(self) -> None:
354355
def get_required_inputs(self) -> Optional[List[str]]:
355356
return self.rec_metrics.get_required_inputs()
356357

358+
def _get_throughput_metric_states(
359+
self, metric: ThroughputMetric
360+
) -> Dict[str, Dict[str, torch.Tensor]]:
361+
states = {}
362+
# this doesn't use `state_dict` as some buffers are not persistent
363+
for name, buf in metric.named_buffers():
364+
states[name] = buf
365+
return {metric._metric_name.value: states}
366+
367+
def _get_metric_states(
368+
self,
369+
metric: RecMetric,
370+
world_size: int,
371+
process_group: Union[dist.ProcessGroup, DeviceMesh],
372+
) -> Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]:
373+
metric_computations = metric._metrics_computations
374+
tasks = metric._tasks
375+
376+
state_aggregated = {}
377+
for task, metric_computation in zip(tasks, metric_computations):
378+
inputs = []
379+
state_aggregated[task.name] = {}
380+
for attr, reduction_fn in metric_computation._reductions.items():
381+
inputs.append((attr, getattr(metric_computation, attr), reduction_fn))
382+
383+
# TODO: do one all gather call per metric, instead of one per state
384+
# this may require more logic as shapes of states are not guranteed to be same
385+
# may need padding
386+
for state, tensor, reduction_fn in inputs:
387+
gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
388+
dist.all_gather(gather_list, tensor, group=process_group)
389+
state_aggregated[task.name][state] = (
390+
reduction_fn(torch.stack(gather_list))
391+
if reduction_fn is not None
392+
else gather_list
393+
)
394+
395+
return state_aggregated
396+
397+
def get_pre_compute_states(
398+
self, pg: Union[dist.ProcessGroup, DeviceMesh]
399+
) -> Dict[str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]]:
400+
"""
401+
This function returns the states per rank for each metric to be saved. The states are are aggregated by the state defined reduction_function.
402+
This can be optionall disabled by setting ``reduce_metrics`` to False. The output on each rank is identical.
403+
404+
Each metric has N number of tasks associated with it. This is reflected in the metric state, where the size of the tensor is
405+
typically (n_tasks, 1). Depending on the `RecComputeMode` the metric is in, the number of tasks can be 1 or len(tasks).
406+
407+
The output of the data is defined as nested dictionary, a dict of ``metric._namespace`` each mapping to a dict of tasks and their states and associated tensors:
408+
metric : str -> { task : {state : tensor or list[tensor]} }
409+
410+
This differs from the state dict such that the metric states are gathered to all ranks within the process group and the reduction function is
411+
applied to them. Typical state dict exposes just the metric states that live on the rank it's called from.
412+
413+
Args:
414+
pg (Union[dist.ProcessGroup, DeviceMesh]): the process group to use for all gather.
415+
reduce_metrics (bool): whether to reduce the metrics or not. Default is True.
416+
417+
Returns:
418+
Dict[str, Dict[str, Dict[str, torch.Tensor]]]: the states for each metric to be saved
419+
"""
420+
if isinstance(pg, DeviceMesh):
421+
process_group: dist.ProcessGroup = pg.get_group(mesh_dim="shard")
422+
else:
423+
process_group: dist.ProcessGroup = pg
424+
aggregated_states = {}
425+
world_size = dist.get_world_size(
426+
process_group
427+
) # Under 2D parallel context, this should be sharding world size
428+
429+
for metric in self.rec_metrics.rec_metrics:
430+
aggregated_states[metric._namespace.value] = self._get_metric_states(
431+
metric,
432+
world_size,
433+
process_group,
434+
)
435+
436+
# throughput metric requires special handling, since it's not a RecMetric
437+
throughput_metric = self.throughput_metric
438+
if throughput_metric is not None:
439+
aggregated_states[throughput_metric._namespace.value] = (
440+
self._get_throughput_metric_states(throughput_metric)
441+
)
442+
443+
return aggregated_states
444+
445+
def load_pre_compute_states(
446+
self,
447+
source: Dict[
448+
str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]
449+
],
450+
) -> None:
451+
"""
452+
Load states from ``get_pre_compute_states``. This is called on every rank, no collectives are called in this function.
453+
454+
Args:
455+
source (Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]): the source states to load from. This
456+
is the output of ``get_pre_compute_states``.
457+
458+
Returns:
459+
None
460+
"""
461+
for metric in self.rec_metrics.rec_metrics:
462+
states = source[metric._namespace.value]
463+
for task, metric_computation in zip(
464+
metric._tasks, metric._metrics_computations
465+
):
466+
state = states[task.name]
467+
for attr, tensor in state.items():
468+
setattr(metric_computation, attr, tensor)
469+
470+
if self.throughput_metric is not None:
471+
states = source[self.throughput_metric._namespace.value][
472+
self.throughput_metric._metric_name.value # pyre-ignore[16]
473+
]
474+
for name, buf in self.throughput_metric.named_buffers(): # pyre-ignore[16]
475+
buf.copy_(states[name])
476+
357477

358478
def _generate_rec_metrics(
359479
metrics_config: MetricsConfig,

torchrec/metrics/rec_metric.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,10 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
479479
yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value, metric_report.description
480480

481481
def _unfused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
482+
"""
483+
For each task, we generate an associated RecMetricComputation object for it.
484+
This would mean in the states of each RecMetricComputation object, the n_tasks dimension is 1.
485+
"""
482486
for task, metric_computation in zip(self._tasks, self._metrics_computations):
483487
metric_computation.pre_compute()
484488
for metric_report in getattr(
@@ -494,6 +498,7 @@ def _unfused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
494498
or metric_computation.has_valid_update[0] > 0
495499
else torch.zeros_like(metric_report.value)
496500
)
501+
# ultimately compute result comes here, and is then written to tensorboard, for fused tasks we need to know the metric prefix val and description
497502
yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value, metric_report.description
498503

499504
def _fuse_update_buffers(self) -> Dict[str, RecModelOutput]:

torchrec/metrics/tests/test_metric_module.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
import torch
2020
import torch.distributed as dist
2121
import torch.distributed.launcher as pet
22+
from torchrec.distributed.test_utils.multi_process import (
23+
MultiProcessContext,
24+
MultiProcessTestBase,
25+
)
2226
from torchrec.metrics.auc import AUCMetric
2327
from torchrec.metrics.metric_module import (
2428
generate_metric_module,
@@ -32,14 +36,15 @@
3236
BatchSizeStage,
3337
DefaultMetricsConfig,
3438
DefaultTaskInfo,
35-
EmptyMetricsConfig,
39+
MetricsConfig,
3640
RecMetricDef,
3741
RecMetricEnum,
3842
)
3943
from torchrec.metrics.model_utils import parse_task_model_outputs
4044
from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo
4145
from torchrec.metrics.test_utils import gen_test_batch, get_launch_config
4246
from torchrec.metrics.throughput import ThroughputMetric
47+
from torchrec.test_utils import seed_and_log, skip_if_asan_class
4348

4449
METRIC_MODULE_PATH = "torchrec.metrics.metric_module"
4550

@@ -603,3 +608,86 @@ def test_save_and_load_state_dict(self) -> None:
603608
no_bss_metric_module.load_state_dict(state_dict)
604609
# Make sure num_batch wasn't created on the throughput module (and no exception was thrown above)
605610
self.assertFalse(hasattr(no_bss_metric_module.throughput_metric, "_num_batch"))
611+
612+
613+
def metric_module_gather_state(
614+
rank: int,
615+
world_size: int,
616+
backend: str,
617+
config: MetricsConfig,
618+
batch_size: int,
619+
local_size: Optional[int] = None,
620+
) -> None:
621+
"""
622+
We compare the computed values of the metric module using the get_pre_compute_states API.
623+
"""
624+
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
625+
metric_module = generate_metric_module(
626+
TestMetricModule,
627+
metrics_config=config,
628+
batch_size=batch_size,
629+
world_size=world_size,
630+
my_rank=rank,
631+
state_metrics_mapping={},
632+
device=ctx.device,
633+
process_group=ctx.pg,
634+
)
635+
636+
test_batches = []
637+
for _ in range(100):
638+
test_batch = gen_test_batch(batch_size)
639+
for k in test_batch.keys():
640+
test_batch[k] = test_batch[k].to(ctx.device)
641+
# save to re run
642+
test_batches.append(test_batch)
643+
metric_module.update(test_batch)
644+
645+
computed_value = metric_module.compute()
646+
states = metric_module.get_pre_compute_states(pg=ctx.pg) # pyre-ignore[6]
647+
648+
torch.distributed.barrier(ctx.pg)
649+
# Compare to computing metrics on metric module that loads from pre_compute_states
650+
new_metric_module = generate_metric_module(
651+
TestMetricModule,
652+
metrics_config=config,
653+
batch_size=batch_size,
654+
world_size=1,
655+
my_rank=0,
656+
state_metrics_mapping={},
657+
device=torch.device(f"cuda:{rank}"),
658+
process_group=dist.new_group(ranks=[rank], backend="nccl"),
659+
)
660+
new_metric_module.load_pre_compute_states(states)
661+
new_computed_value = new_metric_module.compute()
662+
663+
for metric, tensor in computed_value.items():
664+
new_tensor = new_computed_value[metric]
665+
torch.testing.assert_close(tensor, new_tensor, check_device=False)
666+
667+
668+
@skip_if_asan_class
669+
class MetricModuleDistributedTest(MultiProcessTestBase):
670+
671+
@seed_and_log
672+
def setUp(self, backend: str = "nccl") -> None:
673+
super().setUp()
674+
self.backend = backend
675+
676+
if torch.cuda.is_available():
677+
self.device = torch.device("cuda")
678+
else:
679+
self.skipTest("CUDA required for distributed test")
680+
681+
def test_metric_module_gather_state(self) -> None:
682+
world_size = 2
683+
backend = "nccl"
684+
metrics_config = DefaultMetricsConfig
685+
batch_size = 128
686+
687+
self._run_multi_process_test(
688+
callable=metric_module_gather_state,
689+
world_size=world_size,
690+
backend=backend,
691+
batch_size=batch_size,
692+
config=metrics_config,
693+
)

0 commit comments

Comments
 (0)