
Add TBE data configuration reporter to TBE forward #4130


Closed
wants to merge 1 commit
3 changes: 3 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/config/feature_list.py
@@ -60,6 +60,9 @@ def foo():
# Enable bounds_check_indices_v2
BOUNDS_CHECK_INDICES_V2 = auto()

# Enable TBE input parameters extraction
TBE_REPORT_INPUT_PARAMS = auto()

def is_enabled(self) -> bool:
return FeatureGate.is_enabled(self)

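For context, below is a minimal sketch of how the new gate might be queried from user code. It assumes that FeatureGateName is re-exported from fbgemm_gpu.config (as with the existing gates) and that, in OSS builds, a gate can be force-enabled by exporting FBGEMM_<GATE_NAME>; both details are assumptions, not part of this diff.

import os

# Assumption: OSS feature gates can be toggled via FBGEMM_<GATE_NAME> env vars.
os.environ["FBGEMM_TBE_REPORT_INPUT_PARAMS"] = "1"

from fbgemm_gpu.config import FeatureGateName  # assumed re-export location

if FeatureGateName.TBE_REPORT_INPUT_PARAMS.is_enabled():
    print("TBE input-parameter reporting is enabled")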
fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -51,6 +51,7 @@
generate_vbe_metadata,
is_torchdynamo_compiling,
)
from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
from fbgemm_gpu.tbe_input_multiplexer import (
TBEInfo,
TBEInputInfo,
@@ -1441,6 +1442,11 @@ def __init__( # noqa C901
self._debug_print_input_stats_factory()
)

# Get a reporter function pointer
self._report_input_params: Callable[..., None] = (
self.__report_input_params_factory()
)

if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
# Register writeback hook for Exact_SGD optimizer
self.log(
@@ -1952,6 +1958,18 @@ def forward( # noqa: C901
# Print input stats if enabled (for debugging purposes only)
self._debug_print_input_stats(indices, offsets, per_sample_weights)

# Extract and write input stats if enabled
self._report_input_params(
feature_rows=self.rows_per_table,
feature_dims=self.feature_dims,
iteration=self.iter.item() if hasattr(self, "iter") else 0,
indices=indices,
offsets=offsets,
op_id=self.uuid,
per_sample_weights=per_sample_weights,
batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
)

if not is_torchdynamo_compiling():
# Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time

@@ -3792,6 +3810,36 @@ def _debug_print_input_stats_factory_null(
return _debug_print_input_stats_factory_impl
return _debug_print_input_stats_factory_null

@torch.jit.ignore
def __report_input_params_factory(self) -> Callable[..., None]:
"""
This function returns a function pointer based on the `TBE_REPORT_INPUT_PARAMS` feature gate.

If the feature gate is enabled, it returns the `report_stats` method of a `TBEBenchmarkParamsReporter`
instance (configured from the `FBGEMM_REPORT_INPUT_PARAMS_*` environment variables), which:
- Reports input parameters (TBEDataConfig).
- Writes the output as a JSON file.

If the feature gate is disabled, it returns a dummy function pointer that performs no action.
"""

@torch.jit.ignore
def __report_input_params_factory_null(
feature_rows: Tensor,
feature_dims: Tensor,
iteration: int,
indices: Tensor,
offsets: Tensor,
op_id: Optional[str] = None,
per_sample_weights: Optional[Tensor] = None,
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
) -> None:
pass

if FeatureGateName.TBE_REPORT_INPUT_PARAMS.is_enabled():
reporter = TBEBenchmarkParamsReporter.create()
return reporter.report_stats
return __report_input_params_factory_null


class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
"""
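To show where the new hook sits, here is a hedged end-to-end sketch: with the feature gate disabled, the _report_input_params call in forward() is a no-op; with it enabled, each qualifying iteration writes a TBEDataConfig JSON file. The toy table sizes and CPU placement are illustrative assumptions.

import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# One table with 100 rows of dim 8 on CPU (hypothetical toy sizes).
emb = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(100, 8, EmbeddingLocation.HOST, ComputeDevice.CPU)],
)

# Two bags of two indices each; forward() invokes self._report_input_params(...),
# which either reports or silently returns depending on the feature gate.
indices = torch.tensor([3, 7, 11, 42], dtype=torch.long)
offsets = torch.tensor([0, 2, 4], dtype=torch.long)
out = emb(indices, offsets)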
8 changes: 6 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py
@@ -11,8 +11,12 @@
import torch
import yaml

from .tbe_data_config import TBEDataConfig
from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
from fbgemm_gpu.tbe.bench.tbe_data_config import (
BatchParams,
IndicesParams,
PoolingParams,
TBEDataConfig,
)


class TBEDataConfigLoader:
150 changes: 116 additions & 34 deletions fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py
@@ -8,6 +8,7 @@
# pyre-strict

import io
import json
import logging
import os
from typing import List, Optional
@@ -16,18 +17,20 @@
import numpy as np # usort:skip
import torch # usort:skip

from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.bench import (
from fbgemm_gpu.tbe.bench.tbe_data_config import (
BatchParams,
IndicesParams,
PoolingParams,
TBEDataConfig,
)

# pyre-ignore[16]
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
open_source: bool = False
try:
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
if getattr(fbgemm_gpu, "open_source", False):
open_source = True
except Exception:
pass

if open_source:
from fbgemm_gpu.utils import FileStore
@@ -43,7 +46,8 @@ class TBEBenchmarkParamsReporter:
def __init__(
self,
report_interval: int,
report_once: bool = False,
report_iter_start: int = 0,
report_iter_end: int = -1,
bucket: Optional[str] = None,
path_prefix: Optional[str] = None,
) -> None:
@@ -52,13 +56,26 @@ def __init__(

Args:
report_interval (int): The interval at which reports are generated.
report_once (bool, optional): If True, reporting occurs only once. Defaults to False.
report_iter_start (int): The start of the iteration range to capture. Defaults to 0.
report_iter_end (int): The end of the iteration range to capture. Defaults to -1 (last iteration).
bucket (Optional[str], optional): The storage bucket for reports. Defaults to None.
path_prefix (Optional[str], optional): The path prefix for report storage. Defaults to None.
"""
assert report_interval > 0, "report_interval must be greater than 0"
assert (
report_iter_start >= 0
), "report_iter_start must be greater than or equal to 0"
assert (
report_iter_end >= -1
), "report_iter_end must be greater than or equal to -1"
assert (
report_iter_end == -1 or report_iter_start <= report_iter_end
), "report_iter_start must be less than or equal to report_iter_end"

self.report_interval = report_interval
self.report_once = report_once
self.has_reported = False
self.report_iter_start = report_iter_start
self.report_iter_end = report_iter_end
self.path_prefix = path_prefix

default_bucket = "/tmp" if open_source else "tlparse_reports"
bucket = (
@@ -71,19 +88,59 @@ def __init__(
self.logger: logging.Logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)

@classmethod
def create(cls) -> "TBEBenchmarkParamsReporter":
"""
This method returns an instance of TBEBenchmarkParamsReporter based on environment variables.

The reporting interval is taken from the `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` environment variable (defaults to 1). The created instance:
- Reports input parameters (TBEDataConfig).
- Writes the output as a JSON file.

Additionally, the following environment variables are considered:
- `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Specifies the start of the iteration range to capture.
- `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Specifies the end of the iteration range to capture.
- `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Specifies the bucket for reporting.
- `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Specifies the path prefix for reporting.

Returns:
TBEBenchmarkParamsReporter: An instance configured based on the environment variables.
"""
report_interval = int(
os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1")
)
report_iter_start = int(
os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0")
)
report_iter_end = int(
os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1")
)
bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET", "")
path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX", "")

return cls(
report_interval=report_interval,
report_iter_start=report_iter_start,
report_iter_end=report_iter_end,
bucket=bucket,
path_prefix=path_prefix,
)
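As a usage illustration, the reporter can be configured entirely through the environment variables read by create() above; the specific values below are hypothetical.

import os

# Hypothetical settings: report every 100 iterations between iterations 1000 and 2000,
# writing the JSON files under /tmp/tbe_reports.
os.environ["FBGEMM_REPORT_INPUT_PARAMS_INTERVAL"] = "100"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_ITER_START"] = "1000"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_ITER_END"] = "2000"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_BUCKET"] = "/tmp/tbe_reports"

from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter

reporter = TBEBenchmarkParamsReporter.create()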

def extract_params(
self,
embedding_op: SplitTableBatchedEmbeddingBagsCodegen,
feature_rows: torch.Tensor,
feature_dims: torch.Tensor,
indices: torch.Tensor,
offsets: torch.Tensor,
per_sample_weights: Optional[torch.Tensor] = None,
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
) -> TBEDataConfig:
"""
Extracts parameters from the embedding operation, input indices and offsets to create a TBEDataConfig.
Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig.

Args:
embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation.
feature_rows (torch.Tensor): Number of rows in each feature.
feature_dims (torch.Tensor): Number of dimensions in each feature.
indices (torch.Tensor): The input indices tensor.
offsets (torch.Tensor): The input offsets tensor.
per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
@@ -92,24 +149,25 @@ def extract_params(
Returns:
TBEDataConfig: The configuration data for TBE benchmarking.
"""

Es = feature_rows.tolist()
Ds = feature_dims.tolist()

assert len(Es) == len(
Ds
), "feature_rows and feature_dims must have the same length"

# Transfer indices back to CPU for EEG analysis
indices_cpu = indices.cpu()

# Extract embedding table specs
embedding_specs = [
embedding_op.embedding_specs[t] for t in embedding_op.feature_table_map
]
rowcounts = [embedding_spec[0] for embedding_spec in embedding_specs]
dims = [embedding_spec[1] for embedding_spec in embedding_specs]

# Set T to be the number of features we are looking at
T = len(embedding_op.feature_table_map)
T = len(Ds)
# Set E to be the mean of the rowcounts to avoid biasing
E = rowcounts[0] if len(set(rowcounts)) == 1 else np.ceil((np.mean(rowcounts)))
E = Es[0] if len(set(Es)) == 1 else np.ceil((np.mean(Es)))
# Set mixed_dim to be True if there are multiple dims
mixed_dim = len(set(dims)) > 1
mixed_dim = len(set(Ds)) > 1
# Set D to be the mean of the dims to avoid biasing
D = dims[0] if not mixed_dim else np.ceil((np.mean(dims)))
D = Ds[0] if not mixed_dim else np.ceil((np.mean(Ds)))

# Compute indices distribution parameters
heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
@@ -160,34 +218,58 @@ def extract_params(

def report_stats(
self,
embedding_op: SplitTableBatchedEmbeddingBagsCodegen,
feature_rows: torch.Tensor,
feature_dims: torch.Tensor,
iteration: int,
indices: torch.Tensor,
offsets: torch.Tensor,
op_id: str = "",
per_sample_weights: Optional[torch.Tensor] = None,
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
) -> None:
"""
Reports the configuration of the embedding operation and input data then writes the TBE configuration to the filestore.
Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore.

Args:
embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation.
feature_rows (torch.Tensor): Number of rows in each feature.
feature_dims (torch.Tensor): Number of dimensions in each feature.
iteration (int): The current iteration number.
indices (torch.Tensor): The input indices tensor.
offsets (torch.Tensor): The input offsets tensor.
op_id (str, optional): The operation identifier. Defaults to an empty string.
per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
"""
if embedding_op.iter.item() % self.report_interval == 0 and (
not self.report_once or (self.report_once and not self.has_reported)
if (
(iteration - self.report_iter_start) % self.report_interval == 0
and (iteration >= self.report_iter_start)
and (self.report_iter_end == -1 or iteration <= self.report_iter_end)
):
# Extract TBE config
config = self.extract_params(
embedding_op, indices, offsets, per_sample_weights
feature_rows=feature_rows,
feature_dims=feature_dims,
indices=indices,
offsets=offsets,
per_sample_weights=per_sample_weights,
batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
)

config.json()

# Ad-hoc fix for adding Es and Ds to JSON output
# TODO: Remove this once we move Es and Ds to be part of TBEDataConfig
adhoc_config = config.dict()
adhoc_config["Es"] = feature_rows.tolist()
adhoc_config["Ds"] = feature_dims.tolist()
if batch_size_per_feature_per_rank:
adhoc_config["Bs"] = [
sum(batch_size_per_feature_per_rank[f])
for f in range(len(adhoc_config["Es"]))
]

# Write the TBE config to FileStore
self.filestore.write(
f"tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json",
io.BytesIO(config.json(format=True).encode()),
f"{self.path_prefix}tbe-{op_id}-config-estimation-{iteration}.json",
io.BytesIO(json.dumps(adhoc_config, indent=2).encode()),
)

self.has_reported = True
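To make the reporting window concrete, here is a small standalone sketch of the iteration filter used in report_stats() above; same arithmetic, illustration only.

def should_report(iteration: int, interval: int, start: int, end: int) -> bool:
    # Report every `interval` iterations, counted from `start`; `end == -1`
    # means there is no upper bound on the captured range.
    return (
        (iteration - start) % interval == 0
        and iteration >= start
        and (end == -1 or iteration <= end)
    )

# With interval=100, start=1000, end=2000, iterations 1000, 1100, ..., 2000 qualify.
assert [i for i in range(2500) if should_report(i, 100, 1000, 2000)] == list(
    range(1000, 2001, 100)
)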
9 changes: 2 additions & 7 deletions fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py
@@ -14,9 +14,6 @@
import numpy.typing as npt
import torch

# pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed).
from numpy.random import default_rng

from .common import get_device
from .offsets import get_table_batched_offsets_from_dense

@@ -309,11 +306,9 @@ def generate_indices_zipf(
indices, torch.tensor([0, L], dtype=torch.long), True
)
if deterministic_output:
rng = default_rng(12345)
else:
rng = default_rng()
np.random.seed(12345)
permutation = torch.as_tensor(
rng.choice(E, size=indices.max().item() + 1, replace=False)
np.random.choice(E, size=indices.max().item() + 1, replace=False)
)
indices = permutation.gather(0, indices.flatten())
indices = indices.to(get_device()).int()
8 changes: 6 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/utils/filestore.py
@@ -11,7 +11,6 @@
import io
import logging
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import BinaryIO, Union
@@ -76,7 +75,12 @@ def write(
elif isinstance(raw_input, Path):
if not os.path.exists(raw_input):
raise FileNotFoundError(f"File {raw_input} does not exist")
shutil.copyfile(raw_input, filepath)
# Open the source file and destination file, and copy the contents
with open(raw_input, "rb") as src_file, open(
filepath, "wb"
) as dst_file:
while chunk := src_file.read(4096): # Read 4 KB at a time
dst_file.write(chunk)

elif isinstance(raw_input, io.BytesIO) or isinstance(raw_input, BinaryIO):
with open(filepath, "wb") as file: