Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions torchrec/metrics/cpu_offloaded_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
MetricUpdateJob,
SynchronizationMarker,
)
from torchrec.metrics.metric_module import MetricValue, RecMetricModule
from torchrec.metrics.metric_module import MetricsFuture, MetricValue, RecMetricModule
from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
from torchrec.metrics.model_utils import parse_task_model_outputs
from torchrec.metrics.rec_metric import RecMetricException
Expand Down Expand Up @@ -254,24 +254,24 @@ def compute(self) -> Dict[str, MetricValue]:
)

@override
def async_compute(
self, future: concurrent.futures.Future[Dict[str, MetricValue]]
) -> None:
def async_compute(self) -> MetricsFuture:
"""
Entry point for asynchronous metric compute. It enqueues a synchronization marker
to the update queue.

Args:
future: Pre-created future where the computed metrics will be set.
"""
metrics_future = concurrent.futures.Future()
if self._shutdown_event.is_set():
future.set_exception(
metrics_future.set_exception(
RecMetricException("metric processor thread is shut down.")
)
return
return metrics_future

self.update_queue.put_nowait(SynchronizationMarker(future))
self.update_queue.put_nowait(SynchronizationMarker(metrics_future))
self.update_queue_size_logger.add(self.update_queue.qsize())
return metrics_future

def _process_synchronization_marker(
self, synchronization_marker: SynchronizationMarker
Expand Down
5 changes: 2 additions & 3 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@


MetricValue = Union[torch.Tensor, float]
MetricsFuture = concurrent.futures.Future[Dict[str, MetricValue]]


class StateMetric(abc.ABC):
Expand Down Expand Up @@ -490,9 +491,7 @@ def load_pre_compute_states(
def shutdown(self) -> None:
logger.info("Initiating graceful shutdown...")

def async_compute(
self, future: concurrent.futures.Future[Dict[str, MetricValue]]
) -> None:
def async_compute(self) -> MetricsFuture:
raise RecMetricException("async_compute is not supported in RecMetricModule")


Expand Down
18 changes: 4 additions & 14 deletions torchrec/metrics/tests/test_cpu_offloaded_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,6 @@ def test_async_compute_synchronization_marker(self) -> None:

Note that the comms module's metrics are actually the ones that are computed.
"""
future: concurrent.futures.Future[Dict[str, MetricValue]] = (
concurrent.futures.Future()
)

model_out = {
"task1-prediction": torch.tensor([0.5]),
"task1-label": torch.tensor([0.7]),
Expand All @@ -220,7 +216,7 @@ def test_async_compute_synchronization_marker(self) -> None:
for _ in range(10):
self.cpu_module.update(model_out)

self.cpu_module.async_compute(future)
self.cpu_module.async_compute()

comms_mock_metric = cast(
MockRecMetric, self.cpu_module.comms_module.rec_metrics.rec_metrics[0]
Expand All @@ -234,10 +230,7 @@ def test_async_compute_synchronization_marker(self) -> None:
def test_async_compute_after_shutdown(self) -> None:
self.cpu_module.shutdown()

future: concurrent.futures.Future[Dict[str, MetricValue]] = (
concurrent.futures.Future()
)
self.cpu_module.async_compute(future)
future = self.cpu_module.async_compute()

self.assertRaisesRegex(
RecMetricException, "metric processor thread is shut down.", future.result
Expand Down Expand Up @@ -275,7 +268,7 @@ def test_wait_until_queue_is_empty(self) -> None:
"task1-weight": torch.tensor([1.0]),
}
self.cpu_module.update(model_out)
self.cpu_module.async_compute(concurrent.futures.Future())
self.cpu_module.async_compute()

self.cpu_module.wait_until_queue_is_empty(self.cpu_module.update_queue)
self.cpu_module.wait_until_queue_is_empty(self.cpu_module.compute_queue)
Expand Down Expand Up @@ -576,10 +569,7 @@ def _compare_metric_results_worker(

standard_results = standard_module.compute()

future: concurrent.futures.Future[Dict[str, MetricValue]] = (
concurrent.futures.Future()
)
cpu_offloaded_module.async_compute(future)
future = cpu_offloaded_module.async_compute()

# Wait for async compute to finish. Compare the input to each update()
offloaded_results = future.result(timeout=10.0)
Expand Down
2 changes: 1 addition & 1 deletion torchrec/metrics/tests/test_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def test_async_compute_raises_exception(self) -> None:
RecMetricException,
"async_compute is not supported in RecMetricModule",
):
metric_module.async_compute(concurrent.futures.Future())
metric_module.async_compute()


def metric_module_gather_state(
Expand Down
Loading