|
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed.tensor import DeviceMesh
 from torch.profiler import record_function
 from torchrec.metrics.accuracy import AccuracyMetric
 from torchrec.metrics.auc import AUCMetric
@@ -354,6 +355,125 @@ def reset(self) -> None:
|
     def get_required_inputs(self) -> Optional[List[str]]:
         return self.rec_metrics.get_required_inputs()

|
+    def _get_throughput_metric_states(
+        self, metric: ThroughputMetric
+    ) -> Dict[str, Dict[str, torch.Tensor]]:
+        states = {}
+        # This doesn't use `state_dict`, as some buffers are not persistent.
+        for name, buf in metric.named_buffers():
+            states[name] = buf
+        return {metric._metric_name.value: states}
+
+    def _get_metric_states(
+        self,
+        metric: RecMetric,
+        world_size: int,
+        process_group: Union[dist.ProcessGroup, DeviceMesh],
+    ) -> Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]:
+        metric_computations = metric._metrics_computations
+        tasks = metric._tasks
+
+        state_aggregated = {}
+        for task, metric_computation in zip(tasks, metric_computations):
+            inputs = []
+            state_aggregated[task.name] = {}
+            for attr, reduction_fn in metric_computation._reductions.items():
+                inputs.append((attr, getattr(metric_computation, attr), reduction_fn))
+
+            # TODO: do one all_gather call per metric instead of one per state.
+            # This may require more logic, as the shapes of the states are not
+            # guaranteed to be the same and may need padding.
+            for state, tensor, reduction_fn in inputs:
+                gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
+                dist.all_gather(gather_list, tensor, group=process_group)
+                state_aggregated[task.name][state] = (
+                    reduction_fn(torch.stack(gather_list))
+                    if reduction_fn is not None
+                    else gather_list
+                )
+
+        return state_aggregated
+
+    def get_pre_compute_states(
+        self, pg: Union[dist.ProcessGroup, DeviceMesh]
+    ) -> Dict[str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]]:
+        """
+        Returns the states of every metric to be saved. The per-rank states are
+        gathered across the process group and aggregated with each state's defined
+        reduction function, so the output is identical on every rank.
+
+        Each metric has N tasks associated with it. This is reflected in the metric
+        state, where the size of the tensor is typically (n_tasks, 1). Depending on
+        the ``RecComputeMode`` the metric is in, the number of tasks can be 1 or
+        len(tasks).
+
+        The output is a nested dictionary keyed by ``metric._namespace``, with each
+        entry mapping tasks to their states and associated tensors:
+            metric: str -> {task: {state: tensor or list[tensor]}}
+
+        This differs from the state dict in that the metric states are gathered to
+        all ranks within the process group and the reduction function is applied to
+        them; a typical state dict exposes only the metric states that live on the
+        rank it is called from.
+
+        Args:
+            pg (Union[dist.ProcessGroup, DeviceMesh]): the process group to use for
+                the all gather.
+
+        Returns:
+            Dict[str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]]:
+                the states for each metric to be saved.
+        """
+        if isinstance(pg, DeviceMesh):
+            process_group: dist.ProcessGroup = pg.get_group(mesh_dim="shard")
+        else:
+            process_group: dist.ProcessGroup = pg
+        aggregated_states = {}
+        world_size = dist.get_world_size(
+            process_group
+        )  # Under a 2D parallel setup, this is the sharding world size.
+
+        for metric in self.rec_metrics.rec_metrics:
+            aggregated_states[metric._namespace.value] = self._get_metric_states(
+                metric,
+                world_size,
+                process_group,
+            )
+
+        # The throughput metric requires special handling, since it's not a RecMetric.
+        throughput_metric = self.throughput_metric
+        if throughput_metric is not None:
+            aggregated_states[throughput_metric._namespace.value] = (
+                self._get_throughput_metric_states(throughput_metric)
+            )
+
+        return aggregated_states
+
+    def load_pre_compute_states(
+        self,
+        source: Dict[
+            str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]
+        ],
+    ) -> None:
+        """
+        Loads the states produced by ``get_pre_compute_states``. This is called on
+        every rank; no collectives are issued in this function.
+
+        Args:
+            source (Dict[str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]]):
+                the source states to load from. This is the output of
+                ``get_pre_compute_states``.
+
+        Returns:
+            None
+        """
+        for metric in self.rec_metrics.rec_metrics:
+            states = source[metric._namespace.value]
+            for task, metric_computation in zip(
+                metric._tasks, metric._metrics_computations
+            ):
+                state = states[task.name]
+                for attr, tensor in state.items():
+                    setattr(metric_computation, attr, tensor)
+
+        if self.throughput_metric is not None:
+            states = source[self.throughput_metric._namespace.value][
+                self.throughput_metric._metric_name.value  # pyre-ignore[16]
+            ]
+            for name, buf in self.throughput_metric.named_buffers():  # pyre-ignore[16]
+                buf.copy_(states[name])
+

 def _generate_rec_metrics(
     metrics_config: MetricsConfig,
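
To make the intended flow of the new API concrete, here is a minimal usage sketch: gather the reduced metric states before writing a checkpoint, then restore them after resharding. The names `metric_module`, `save_metric_states`, and `restore_metric_states` are illustrative and not part of this diff; only `get_pre_compute_states` and `load_pre_compute_states` come from the change above.

import torch
import torch.distributed as dist


def save_metric_states(metric_module, path: str) -> None:
    # Every rank participates in the all_gather inside get_pre_compute_states;
    # the result is identical on each rank, so only rank 0 needs to persist it.
    states = metric_module.get_pre_compute_states(dist.group.WORLD)
    if dist.get_rank() == 0:
        torch.save(states, path)


def restore_metric_states(metric_module, path: str) -> None:
    # load_pre_compute_states issues no collectives, so every rank can load
    # the same file independently.
    states = torch.load(path, map_location="cpu")
    metric_module.load_pre_compute_states(states)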
|
|
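The core of `_get_metric_states` is one `all_gather` per metric state followed by that state's reduction function. Below is a standalone, single-process sketch of that pattern; the gloo backend, the example tensor, and the sum reduction are stand-ins for illustration only.

import os

import torch
import torch.distributed as dist

# Single-process group so the sketch runs without torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Stand-in for one metric state, e.g. a (n_tasks, 1) buffer.
local_state = torch.tensor([[3.0]])
world_size = dist.get_world_size()

# One all_gather per state; stacking gives a (world_size, n_tasks, 1) tensor
# that the reduction function (here: sum over ranks) collapses back down.
gather_list = [torch.empty_like(local_state) for _ in range(world_size)]
dist.all_gather(gather_list, local_state)
reduced = torch.stack(gather_list).sum(dim=0)

print(reduced)  # tensor([[3.]]) with a single rank
dist.destroy_process_group()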
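When a `DeviceMesh` is passed instead of a `ProcessGroup`, `get_pre_compute_states` assumes the mesh has a dimension named "shard" and gathers over that group only. A sketch of that setup follows; the mesh shape, device type, and the "replicate" dimension name are illustrative assumptions, with only the "shard" name required by the code above.

# Launch with e.g.: torchrun --nproc_per_node=8 this_script.py
from torch.distributed.device_mesh import init_device_mesh

# 8 ranks arranged as 2 replica groups x 4 shards.
mesh = init_device_mesh("cpu", (2, 4), mesh_dim_names=("replicate", "shard"))

# get_pre_compute_states(mesh) resolves the group internally via
# mesh.get_group(mesh_dim="shard"), so the all_gather spans only the 4 shards.
shard_group = mesh.get_group(mesh_dim="shard")
print(shard_group.size())  # 4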