diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index d03699282b..a572de0738 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -1722,6 +1722,119 @@ def _get_tensor_memory(self, tensor_name: str) -> int:
         tensor = getattr(self, tensor_name)
         return tensor.numel() * tensor.element_size()
 
+    def _categorize_memory_by_location(
+        self, tensor_names: list[str]
+    ) -> tuple[int, int]:
+        """Categorize memory into HBM and UVM for given tensors.
+
+        Returns:
+            (hbm_bytes, uvm_bytes)
+        """
+        uvm_set = set(self._uvm_tensors_log)
+        hbm_bytes = 0
+        uvm_bytes = 0
+
+        for name in tensor_names:
+            size = self._get_tensor_memory(name)
+            if name in uvm_set:
+                uvm_bytes += size
+            else:
+                hbm_bytes += size
+
+        return hbm_bytes, uvm_bytes
+
+    def _report_hbm_breakdown(
+        self,
+        stats_reporter: TBEStatsReporter,
+        embeddings: int,
+        optimizer_states: int,
+        cache: int,
+        total_static_sparse: int,
+        ephemeral: int,
+    ) -> None:
+        """Report HBM memory breakdown to stats reporter."""
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.embeddings",
+            data_bytes=embeddings,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.optimizer_states",
+            data_bytes=optimizer_states,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.cache",
+            data_bytes=cache,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.total_static_sparse",
+            data_bytes=total_static_sparse,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.hbm.ephemeral",
+            data_bytes=ephemeral,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+
+    def _report_uvm_breakdown(
+        self,
+        stats_reporter: TBEStatsReporter,
+        embeddings: int,
+        optimizer_states: int,
+        cache: int,
+        total_static_sparse: int,
+        ephemeral: int,
+    ) -> None:
+        """Report UVM memory breakdown to stats reporter."""
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.embeddings",
+            data_bytes=embeddings,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.optimizer_states",
+            data_bytes=optimizer_states,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.cache",
+            data_bytes=cache,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.total_static_sparse",
+            data_bytes=total_static_sparse,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="tbe.uvm.ephemeral",
+            data_bytes=ephemeral,
+            embedding_id=self.logging_table_name,
+            tbe_id=self.uuid,
+        )
+
     @torch.jit.ignore
     def _report_tbe_mem_usage(self) -> None:
         if self.stats_reporter is None:
@@ -1731,10 +1844,12 @@ def _report_tbe_mem_usage(self) -> None:
         if not stats_reporter.should_report(self.step):
             return
 
+        # Calculate total memory from all parameters and buffers (always needed)
         total_mem_usage = sum(
             p.numel() * p.element_size() for p in self.parameters()
         ) + sum(b.numel() * b.element_size() for b in self.buffers())
 
+        # Calculate total HBM and UVM usage (always needed)
         if self.use_cpu:
            total_hbm_usage = 0
            total_uvm_usage = total_mem_usage
@@ -1746,6 +1861,7 @@ def _report_tbe_mem_usage(self) -> None:
             )
             total_hbm_usage = total_mem_usage - total_uvm_usage
 
+        # Report total memory usage metrics (always reported for backward compatibility)
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="tbe.total_hbm_usage",
@@ -1761,6 +1877,76 @@ def _report_tbe_mem_usage(self) -> None:
             tbe_id=self.uuid,
         )
 
+        # Check if detailed memory breakdown is enabled via environment variable
+        # Set FBGEMM_TBE_MEM_BREAKDOWN=1 to enable expensive detailed breakdown
+        enable_detailed_breakdown = (
+            int(os.environ.get("FBGEMM_TBE_MEM_BREAKDOWN", "0")) == 1
+        )
+
+        if not enable_detailed_breakdown:
+            return
+
+        # Tensor groups for sparse memory categorization
+        weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
+        optimizer_tensors = [
+            "momentum1_dev",
+            "momentum1_host",
+            "momentum1_uvm",
+            "momentum2_dev",
+            "momentum2_host",
+            "momentum2_uvm",
+        ]
+        cache_tensors = [
+            "lxu_cache_weights",
+            "lxu_cache_state",
+            "lxu_state",
+            "cache_hash_size_cumsum",
+            "cache_index_table_map",
+            "cache_miss_counter",
+            "lxu_cache_locking_counter",
+        ]
+
+        # Calculate total memory for each component
+        weights_total = sum(self._get_tensor_memory(t) for t in weight_tensors)
+        optimizer_total = sum(self._get_tensor_memory(t) for t in optimizer_tensors)
+        cache_total = sum(self._get_tensor_memory(t) for t in cache_tensors)
+
+        # Categorize memory by location (HBM vs UVM)
+        if self.use_cpu:
+            weights_hbm, weights_uvm = 0, weights_total
+            opt_hbm, opt_uvm = 0, optimizer_total
+            cache_hbm, cache_uvm = 0, cache_total
+        else:
+            weights_hbm, weights_uvm = self._categorize_memory_by_location(
+                weight_tensors
+            )
+            opt_hbm, opt_uvm = self._categorize_memory_by_location(optimizer_tensors)
+            cache_hbm, cache_uvm = self._categorize_memory_by_location(cache_tensors)
+
+        # Calculate ephemeral memory split between HBM and UVM
+        static_sparse_hbm = weights_hbm + opt_hbm + cache_hbm
+        static_sparse_uvm = weights_uvm + opt_uvm + cache_uvm
+        ephemeral_hbm = total_hbm_usage - static_sparse_hbm
+        ephemeral_uvm = total_uvm_usage - static_sparse_uvm
+
+        # Report granular memory breakdowns
+        self._report_hbm_breakdown(
+            stats_reporter,
+            weights_hbm,
+            opt_hbm,
+            cache_hbm,
+            static_sparse_hbm,
+            ephemeral_hbm,
+        )
+        self._report_uvm_breakdown(
+            stats_reporter,
+            weights_uvm,
+            opt_uvm,
+            cache_uvm,
+            static_sparse_uvm,
+            ephemeral_uvm,
+        )
+
     @torch.jit.ignore
     def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
         if self.stats_reporter is None:
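
Usage note (not part of the patch): a minimal sketch of how the detailed breakdown might be enabled. Only the FBGEMM_TBE_MEM_BREAKDOWN environment variable and the tbe.hbm.* / tbe.uvm.* event names come from the diff above; the surrounding process setup is assumed.

import os

# Hypothetical opt-in: set the flag before the TBE module starts reporting
# stats (e.g. at process start-up), so _report_tbe_mem_usage sees it.
os.environ["FBGEMM_TBE_MEM_BREAKDOWN"] = "1"

# Once enabled, the patch emits these extra events for <loc> in {hbm, uvm}:
#   tbe.<loc>.embeddings          - weights_dev / weights_host / weights_uvm
#   tbe.<loc>.optimizer_states    - momentum1_* / momentum2_* tensors
#   tbe.<loc>.cache               - lxu_cache_* and related cache tensors
#   tbe.<loc>.total_static_sparse - sum of the three groups above
#   tbe.<loc>.ephemeral           - total HBM/UVM usage minus the static total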