186 changes: 186 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -1722,6 +1722,119 @@ def _get_tensor_memory(self, tensor_name: str) -> int:
tensor = getattr(self, tensor_name)
return tensor.numel() * tensor.element_size()

def _categorize_memory_by_location(
self, tensor_names: list[str]
) -> tuple[int, int]:
"""Categorize memory into HBM and UVM for given tensors.

Returns:
(hbm_bytes, uvm_bytes)
"""
uvm_set = set(self._uvm_tensors_log)
hbm_bytes = 0
uvm_bytes = 0

for name in tensor_names:
size = self._get_tensor_memory(name)
if name in uvm_set:
uvm_bytes += size
else:
hbm_bytes += size

return hbm_bytes, uvm_bytes

def _report_hbm_breakdown(
self,
stats_reporter: TBEStatsReporter,
embeddings: int,
optimizer_states: int,
cache: int,
total_static_sparse: int,
ephemeral: int,
) -> None:
"""Report HBM memory breakdown to stats reporter."""
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.hbm.embeddings",
data_bytes=embeddings,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.hbm.optimizer_states",
data_bytes=optimizer_states,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.hbm.cache",
data_bytes=cache,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.hbm.total_static_sparse",
data_bytes=total_static_sparse,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.hbm.ephemeral",
data_bytes=ephemeral,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)

def _report_uvm_breakdown(
self,
stats_reporter: TBEStatsReporter,
embeddings: int,
optimizer_states: int,
cache: int,
total_static_sparse: int,
ephemeral: int,
) -> None:
"""Report UVM memory breakdown to stats reporter."""
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.uvm.embeddings",
data_bytes=embeddings,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.uvm.optimizer_states",
data_bytes=optimizer_states,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.uvm.cache",
data_bytes=cache,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.uvm.total_static_sparse",
data_bytes=total_static_sparse,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.uvm.ephemeral",
data_bytes=ephemeral,
embedding_id=self.logging_table_name,
tbe_id=self.uuid,
)

@torch.jit.ignore
def _report_tbe_mem_usage(self) -> None:
if self.stats_reporter is None:
@@ -1731,10 +1844,12 @@ def _report_tbe_mem_usage(self) -> None:
if not stats_reporter.should_report(self.step):
return

# Calculate total memory from all parameters and buffers (always needed)
total_mem_usage = sum(
p.numel() * p.element_size() for p in self.parameters()
) + sum(b.numel() * b.element_size() for b in self.buffers())

# Calculate total HBM and UVM usage (always needed)
if self.use_cpu:
total_hbm_usage = 0
total_uvm_usage = total_mem_usage
@@ -1746,6 +1861,7 @@
)
total_hbm_usage = total_mem_usage - total_uvm_usage

# Report total memory usage metrics (always reported for backward compatibility)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="tbe.total_hbm_usage",
@@ -1761,6 +1877,76 @@
tbe_id=self.uuid,
)

# Check if detailed memory breakdown is enabled via environment variable
# Set FBGEMM_TBE_MEM_BREAKDOWN=1 to enable expensive detailed breakdown
enable_detailed_breakdown = (
int(os.environ.get("FBGEMM_TBE_MEM_BREAKDOWN", "0")) == 1
)

if not enable_detailed_breakdown:
return
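# Usage note (illustrative): enable the breakdown with
#   FBGEMM_TBE_MEM_BREAKDOWN=1 python train.py
# or os.environ["FBGEMM_TBE_MEM_BREAKDOWN"] = "1" before this point; the
# per-tensor accounting below is skipped by default to keep reporting cheap.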

# Tensor groups for sparse memory categorization
weight_tensors = ["weights_dev", "weights_host", "weights_uvm"]
optimizer_tensors = [
"momentum1_dev",
"momentum1_host",
"momentum1_uvm",
"momentum2_dev",
"momentum2_host",
"momentum2_uvm",
]
cache_tensors = [
"lxu_cache_weights",
"lxu_cache_state",
"lxu_state",
"cache_hash_size_cumsum",
"cache_index_table_map",
"cache_miss_counter",
"lxu_cache_locking_counter",
]
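# Note: the weights_* buffers are the embedding tables split by placement
# (device, host, managed/UVM); momentum1_*/momentum2_* hold optimizer state
# (momentum2_* is only non-empty for optimizers that use a second moment); the
# lxu_*/cache_* buffers back the software-managed cache.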

# Calculate total memory for each component
weights_total = sum(self._get_tensor_memory(t) for t in weight_tensors)
optimizer_total = sum(self._get_tensor_memory(t) for t in optimizer_tensors)
cache_total = sum(self._get_tensor_memory(t) for t in cache_tensors)

# Categorize memory by location (HBM vs UVM)
if self.use_cpu:
weights_hbm, weights_uvm = 0, weights_total
opt_hbm, opt_uvm = 0, optimizer_total
cache_hbm, cache_uvm = 0, cache_total
else:
weights_hbm, weights_uvm = self._categorize_memory_by_location(
weight_tensors
)
opt_hbm, opt_uvm = self._categorize_memory_by_location(optimizer_tensors)
cache_hbm, cache_uvm = self._categorize_memory_by_location(cache_tensors)

# Calculate ephemeral memory split between HBM and UVM
static_sparse_hbm = weights_hbm + opt_hbm + cache_hbm
static_sparse_uvm = weights_uvm + opt_uvm + cache_uvm
ephemeral_hbm = total_hbm_usage - static_sparse_hbm
ephemeral_uvm = total_uvm_usage - static_sparse_uvm
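# Illustrative arithmetic (hypothetical numbers): if total_hbm_usage is 10 GiB
# and static_sparse_hbm (weights + optimizer + cache) is 9 GiB, ephemeral_hbm
# is the remaining 1 GiB of parameters/buffers not attributed to those groups.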

# Report granular memory breakdowns
self._report_hbm_breakdown(
stats_reporter,
weights_hbm,
opt_hbm,
cache_hbm,
static_sparse_hbm,
ephemeral_hbm,
)
self._report_uvm_breakdown(
stats_reporter,
weights_uvm,
opt_uvm,
cache_uvm,
static_sparse_uvm,
ephemeral_uvm,
)

@torch.jit.ignore
def _report_io_size_count(self, event: str, data: Tensor) -> Tensor:
if self.stats_reporter is None:
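For context only, the sketch below shows one way a caller might collect the new tbe.hbm.* / tbe.uvm.* events; it is not part of this change. BreakdownCollector is a hypothetical name that implements only the two calls the diff relies on (should_report and report_data_amount), not the full TBEStatsReporter interface, and how the reporter is wired into the TBE module is assumed to stay unchanged.

import os
from collections import defaultdict
from typing import DefaultDict, List, Tuple


class BreakdownCollector:
    """Hypothetical sink for the per-location breakdown events emitted above."""

    def __init__(self, report_interval: int = 100) -> None:
        self.report_interval = report_interval
        # event_name -> list of (iteration_step, data_bytes)
        self.samples: DefaultDict[str, List[Tuple[int, int]]] = defaultdict(list)

    def should_report(self, step: int) -> bool:
        # Report every `report_interval` steps, mirroring the gating above.
        return step % self.report_interval == 0

    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
    ) -> None:
        # Keep only the per-location breakdown events introduced by this change.
        if event_name.startswith(("tbe.hbm.", "tbe.uvm.")):
            self.samples[event_name].append((iteration_step, data_bytes))


# The detailed breakdown is opt-in; set the flag before reporting runs.
os.environ["FBGEMM_TBE_MEM_BREAKDOWN"] = "1"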