From 1fa04932e18836d2d9f440d67e3573da8918922e Mon Sep 17 00:00:00 2001 From: Gantaphon Chalumporn Date: Fri, 23 May 2025 15:23:23 -0700 Subject: [PATCH] Add TBE data configuration reporter to TBE forward (#4130) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1211 # Add TBE data configuration reporter to TBE forward call. The reporter reports the TBE data configuration at each `SplitTableBatchedEmbeddingBagsCodegen` ***forward*** call. The output is a `TBEDataConfig` object, which is written to one or more JSON files. The configuration of its environment variables and an example of its usage are described below. ## Just Knobs for enablement - fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS is added for enablement of the reporter (https://www.internalfb.com/intern/justknobs/?name=fbgemm_gpu%2Ffeatures) - The default is `False`; turn this flag on to enable the reporter. - To enable it locally, use: ``` jk canary set fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS --on --ttl 600 ``` ## Environment Variables --------------------- The Reporter relies on several environment variables to control its behavior. Below is a description of each variable: - **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL**: - **Description**: Determines the interval at which reports are generated. This is specified in terms of the number of iterations. - **Example Value**: `1` (report every iteration) - **FBGEMM_REPORT_INPUT_PARAMS_ITER_START**: - **Description**: Specifies the start of the iteration range to capture reports. Defaults to `0`. - **Example Value**: `0` (start reporting from the first iteration) - **FBGEMM_REPORT_INPUT_PARAMS_ITER_END**: - **Description**: Specifies the end of the iteration range to capture reports. Use `-1` to report until the last iteration. Defaults to `-1`. - **Example Value**: `-1` (report until the last iteration) - **FBGEMM_REPORT_INPUT_PARAMS_BUCKET**: * **Description**: Specifies the name of the Manifold bucket where the report data will be saved. 
* **Example Value**: `tlparse_reports` - **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX**: - **Description**: Defines the path prefix where the report files will be stored. - **Example Value**: `tree/tests/` ## Example Usage ------------- Below is an example command demonstrating how to use the FBGEMM Reporter with specific environment variable settings: ``` FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2 FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3 FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/ buck2 run mode/opt //deeplearning/fbgemm/fbgemm_gpu/bench:split_table_batched_embeddings -- device --iters 2 ``` **Explanation** The above setting will report `iter 3` and `iter 5`. * **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2**: The reporter will generate a report every 2 iterations. * **FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3**: The reporter will start generating reports from iteration 3. * **FBGEMM_REPORT_INPUT_PARAMS_ITER_END=-1 (Default)**: The reporter will continue to generate reports until the last iteration. * **FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports**: The reports will be saved in the `tlparse_reports` bucket. * **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/**: The reports will be stored with the path prefix `tree/tests/`. For Manifold, make sure all folders within the path exist. **Note on Benchmark example** Note that with the `--iters 2` option, the benchmark will execute 6 forward calls in total: 3 calls (2 iterations plus 1 warmup) for the forward benchmark and another 3 calls (2 iterations plus 1 warmup) for the backward benchmark. Iteration numbering starts from 0, so the calls are iterations 0 through 5. --- --- ## Other changes included in this Diff: - Update the build dependencies of the tbe_data_config* files - Remove the `shutil` and `numpy.random` dependencies, as they cause incompatibility errors. 
- Add non-OSS test, writing extracted config data json file to Manifold Reviewed By: q10 Differential Revision: D73927918 --- fbgemm_gpu/fbgemm_gpu/config/feature_list.py | 3 + ...t_table_batched_embeddings_ops_training.py | 48 ++++++ .../tbe/bench/tbe_data_config_loader.py | 8 +- .../tbe/stats/bench_params_reporter.py | 150 +++++++++++++---- fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py | 9 +- fbgemm_gpu/fbgemm_gpu/utils/filestore.py | 8 +- .../include/fbgemm_gpu/config/feature_gates.h | 3 +- .../stats/tbe_bench_params_reporter_test.py | 156 ++++++++++++++++-- fbgemm_gpu/test/utils/filestore_test.py | 4 + 9 files changed, 325 insertions(+), 64 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py index 14b2bbe2a9..8f8fd6f495 100644 --- a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py +++ b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py @@ -60,6 +60,9 @@ def foo(): # Enable bounds_check_indices_v2 BOUNDS_CHECK_INDICES_V2 = auto() + # Enable TBE input parameters extraction + TBE_REPORT_INPUT_PARAMS = auto() + def is_enabled(self) -> bool: return FeatureGate.is_enabled(self) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 4c06501815..e9765984c7 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -51,6 +51,7 @@ generate_vbe_metadata, is_torchdynamo_compiling, ) +from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter from fbgemm_gpu.tbe_input_multiplexer import ( TBEInfo, TBEInputInfo, @@ -1441,6 +1442,11 @@ def __init__( # noqa C901 self._debug_print_input_stats_factory() ) + # Get a reporter function pointer + self._report_input_params: Callable[..., None] = ( + self.__report_input_params_factory() + ) + if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook: # Register 
writeback hook for Exact_SGD optimizer self.log( @@ -1952,6 +1958,18 @@ def forward( # noqa: C901 # Print input stats if enable (for debugging purpose only) self._debug_print_input_stats(indices, offsets, per_sample_weights) + # Extract and Write input stats if enable + self._report_input_params( + feature_rows=self.rows_per_table, + feature_dims=self.feature_dims, + iteration=self.iter.item() if hasattr(self, "iter") else 0, + indices=indices, + offsets=offsets, + op_id=self.uuid, + per_sample_weights=per_sample_weights, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + if not is_torchdynamo_compiling(): # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time @@ -3792,6 +3810,36 @@ def _debug_print_input_stats_factory_null( return _debug_print_input_stats_factory_impl return _debug_print_input_stats_factory_null + @torch.jit.ignore + def __report_input_params_factory(self) -> Callable[..., None]: + """ + This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`. + + If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that: + - Reports input parameters (TBEDataConfig). + - Writes the output as a JSON file. + + If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action. 
+ """ + + @torch.jit.ignore + def __report_input_params_factory_null( + feature_rows: Tensor, + feature_dims: Tensor, + iteration: int, + indices: Tensor, + offsets: Tensor, + op_id: Optional[str] = None, + per_sample_weights: Optional[Tensor] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, + ) -> None: + pass + + if FeatureGateName.TBE_REPORT_INPUT_PARAMS.is_enabled(): + reporter = TBEBenchmarkParamsReporter.create() + return reporter.report_stats + return __report_input_params_factory_null + class DenseTableBatchedEmbeddingBagsCodegen(nn.Module): """ diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py index e27a5dec0c..96e26554c0 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py @@ -11,8 +11,12 @@ import torch import yaml -from .tbe_data_config import TBEDataConfig -from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams +from fbgemm_gpu.tbe.bench.tbe_data_config import ( + BatchParams, + IndicesParams, + PoolingParams, + TBEDataConfig, +) class TBEDataConfigLoader: diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py index 794b38a20e..226347de29 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py @@ -8,6 +8,7 @@ # pyre-strict import io +import json import logging import os from typing import List, Optional @@ -16,18 +17,20 @@ import numpy as np # usort:skip import torch # usort:skip -from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( - SplitTableBatchedEmbeddingBagsCodegen, -) -from fbgemm_gpu.tbe.bench import ( +from fbgemm_gpu.tbe.bench.tbe_data_config import ( BatchParams, IndicesParams, PoolingParams, TBEDataConfig, ) -# pyre-ignore[16] -open_source: bool = 
getattr(fbgemm_gpu, "open_source", False) +open_source: bool = False +try: + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. + if getattr(fbgemm_gpu, "open_source", False): + open_source = True +except Exception: + pass if open_source: from fbgemm_gpu.utils import FileStore @@ -43,7 +46,8 @@ class TBEBenchmarkParamsReporter: def __init__( self, report_interval: int, - report_once: bool = False, + report_iter_start: int = 0, + report_iter_end: int = -1, bucket: Optional[str] = None, path_prefix: Optional[str] = None, ) -> None: @@ -52,13 +56,26 @@ def __init__( Args: report_interval (int): The interval at which reports are generated. - report_once (bool, optional): If True, reporting occurs only once. Defaults to False. + report_iter_start (int): The start of the iteration range to capture. Defaults to 0. + report_iter_end (int): The end of the iteration range to capture. Defaults to -1 (last iteration). bucket (Optional[str], optional): The storage bucket for reports. Defaults to None. path_prefix (Optional[str], optional): The path prefix for report storage. Defaults to None. 
""" + assert report_interval > 0, "report_interval must be greater than 0" + assert ( + report_iter_start >= 0 + ), "report_iter_start must be greater than or equal to 0" + assert ( + report_iter_end >= -1 + ), "report_iter_end must be greater than or equal to -1" + assert ( + report_iter_end == -1 or report_iter_start <= report_iter_end + ), "report_iter_start must be less than or equal to report_iter_end" + self.report_interval = report_interval - self.report_once = report_once - self.has_reported = False + self.report_iter_start = report_iter_start + self.report_iter_end = report_iter_end + self.path_prefix = path_prefix default_bucket = "/tmp" if open_source else "tlparse_reports" bucket = ( @@ -71,19 +88,59 @@ def __init__( self.logger: logging.Logger = logging.getLogger(__name__) self.logger.setLevel(logging.INFO) + @classmethod + def create(cls) -> "TBEBenchmarkParamsReporter": + """ + This method returns an instance of TBEBenchmarkParamsReporter based on environment variables. + + If the `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` environment variable is set to a value greater than 0, it creates an instance that: + - Reports input parameters (TBEDataConfig). + - Writes the output as a JSON file. + + Additionally, the following environment variables are considered: + - `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Specifies the start of the iteration range to capture. + - `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Specifies the end of the iteration range to capture. + - `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Specifies the bucket for reporting. + - `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Specifies the path prefix for reporting. + + Returns: + TBEBenchmarkParamsReporter: An instance configured based on the environment variables. 
+ """ + report_interval = int( + os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1") + ) + report_iter_start = int( + os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0") + ) + report_iter_end = int( + os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1") + ) + bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET", "") + path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX", "") + + return cls( + report_interval=report_interval, + report_iter_start=report_iter_start, + report_iter_end=report_iter_end, + bucket=bucket, + path_prefix=path_prefix, + ) + def extract_params( self, - embedding_op: SplitTableBatchedEmbeddingBagsCodegen, + feature_rows: torch.Tensor, + feature_dims: torch.Tensor, indices: torch.Tensor, offsets: torch.Tensor, per_sample_weights: Optional[torch.Tensor] = None, batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> TBEDataConfig: """ - Extracts parameters from the embedding operation, input indices and offsets to create a TBEDataConfig. + Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig. Args: - embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation. + feature_rows (torch.Tensor): Number of rows in each feature. + feature_dims (torch.Tensor): Number of dimensions in each feature. indices (torch.Tensor): The input indices tensor. offsets (torch.Tensor): The input offsets tensor. per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None. @@ -92,24 +149,25 @@ def extract_params( Returns: TBEDataConfig: The configuration data for TBE benchmarking. 
""" + + Es = feature_rows.tolist() + Ds = feature_dims.tolist() + + assert len(Es) == len( + Ds + ), "feature_rows and feature_dims must have the same length" + # Transfer indices back to CPU for EEG analysis indices_cpu = indices.cpu() - # Extract embedding table specs - embedding_specs = [ - embedding_op.embedding_specs[t] for t in embedding_op.feature_table_map - ] - rowcounts = [embedding_spec[0] for embedding_spec in embedding_specs] - dims = [embedding_spec[1] for embedding_spec in embedding_specs] - # Set T to be the number of features we are looking at - T = len(embedding_op.feature_table_map) + T = len(Ds) # Set E to be the mean of the rowcounts to avoid biasing - E = rowcounts[0] if len(set(rowcounts)) == 1 else np.ceil((np.mean(rowcounts))) + E = Es[0] if len(set(Es)) == 1 else np.ceil((np.mean(Es))) # Set mixed_dim to be True if there are multiple dims - mixed_dim = len(set(dims)) > 1 + mixed_dim = len(set(Ds)) > 1 # Set D to be the mean of the dims to avoid biasing - D = dims[0] if not mixed_dim else np.ceil((np.mean(dims))) + D = Ds[0] if not mixed_dim else np.ceil((np.mean(Ds))) # Compute indices distribution parameters heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution( @@ -160,34 +218,58 @@ def extract_params( def report_stats( self, - embedding_op: SplitTableBatchedEmbeddingBagsCodegen, + feature_rows: torch.Tensor, + feature_dims: torch.Tensor, + iteration: int, indices: torch.Tensor, offsets: torch.Tensor, + op_id: str = "", per_sample_weights: Optional[torch.Tensor] = None, batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> None: """ - Reports the configuration of the embedding operation and input data then writes the TBE configuration to the filestore. + Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore. Args: - embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation. 
+ feature_rows (torch.Tensor): Number of rows in each feature. + feature_dims (torch.Tensor): Number of dimensions in each feature. + iteration (int): The current iteration number. indices (torch.Tensor): The input indices tensor. offsets (torch.Tensor): The input offsets tensor. + op_id (str, optional): The operation identifier. Defaults to an empty string. per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None. batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None. """ - if embedding_op.iter.item() % self.report_interval == 0 and ( - not self.report_once or (self.report_once and not self.has_reported) + if ( + (iteration - self.report_iter_start) % self.report_interval == 0 + and (iteration >= self.report_iter_start) + and (self.report_iter_end == -1 or iteration <= self.report_iter_end) ): # Extract TBE config config = self.extract_params( - embedding_op, indices, offsets, per_sample_weights + feature_rows=feature_rows, + feature_dims=feature_dims, + indices=indices, + offsets=offsets, + per_sample_weights=per_sample_weights, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, ) + config.json() + + # Ad-hoc fix for adding Es and Ds to JSON output + # TODO: Remove this once we moved Es and Ds to be part of TBEDataConfig + adhoc_config = config.dict() + adhoc_config["Es"] = feature_rows.tolist() + adhoc_config["Ds"] = feature_dims.tolist() + if batch_size_per_feature_per_rank: + adhoc_config["Bs"] = [ + sum(batch_size_per_feature_per_rank[f]) + for f in range(len(adhoc_config["Es"])) + ] + # Write the TBE config to FileStore self.filestore.write( - f"tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json", - io.BytesIO(config.json(format=True).encode()), + f"{self.path_prefix}tbe-{op_id}-config-estimation-{iteration}.json", + io.BytesIO(json.dumps(adhoc_config, indent=2).encode()), ) - - self.has_reported = True diff --git 
a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py index bd64223a09..9eded1ce5d 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py @@ -14,9 +14,6 @@ import numpy.typing as npt import torch -# pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed). -from numpy.random import default_rng - from .common import get_device from .offsets import get_table_batched_offsets_from_dense @@ -309,11 +306,9 @@ def generate_indices_zipf( indices, torch.tensor([0, L], dtype=torch.long), True ) if deterministic_output: - rng = default_rng(12345) - else: - rng = default_rng() + np.random.seed(12345) permutation = torch.as_tensor( - rng.choice(E, size=indices.max().item() + 1, replace=False) + np.random.choice(E, size=indices.max().item() + 1, replace=False) ) indices = permutation.gather(0, indices.flatten()) indices = indices.to(get_device()).int() diff --git a/fbgemm_gpu/fbgemm_gpu/utils/filestore.py b/fbgemm_gpu/fbgemm_gpu/utils/filestore.py index 9261f85922..293be5f925 100644 --- a/fbgemm_gpu/fbgemm_gpu/utils/filestore.py +++ b/fbgemm_gpu/fbgemm_gpu/utils/filestore.py @@ -11,7 +11,6 @@ import io import logging import os -import shutil from dataclasses import dataclass from pathlib import Path from typing import BinaryIO, Union @@ -76,7 +75,12 @@ def write( elif isinstance(raw_input, Path): if not os.path.exists(raw_input): raise FileNotFoundError(f"File {raw_input} does not exist") - shutil.copyfile(raw_input, filepath) + # Open the source file and destination file, and copy the contents + with open(raw_input, "rb") as src_file, open( + filepath, "wb" + ) as dst_file: + while chunk := src_file.read(4096): # Read 4 KB at a time + dst_file.write(chunk) elif isinstance(raw_input, io.BytesIO) or isinstance(raw_input, BinaryIO): with open(filepath, "wb") as file: diff --git a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h 
b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h index 11c4d55763..9018e68603 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h +++ b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h @@ -61,7 +61,8 @@ namespace fbgemm_gpu::config { X(TBE_ANNOTATE_KINETO_TRACE) \ X(TBE_ROCM_INFERENCE_PACKED_BAGS) \ X(TBE_ROCM_HIP_BACKWARD_KERNEL) \ - X(BOUNDS_CHECK_INDICES_V2) + X(BOUNDS_CHECK_INDICES_V2) \ + X(TBE_REPORT_INPUT_PARAMS) // X(EXAMPLE_FEATURE_FLAG) /// @ingroup fbgemm-gpu-config diff --git a/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py b/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py index 112a174091..9fec3cd693 100644 --- a/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py +++ b/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py @@ -8,13 +8,20 @@ # pyre-strict import unittest +from typing import Optional +from unittest.mock import patch + +import fbgemm_gpu import hypothesis.strategies as st import torch -from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation -from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( +from fbgemm_gpu.config import FeatureGateName +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( ComputeDevice, + EmbeddingLocation, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( SplitTableBatchedEmbeddingBagsCodegen, ) from fbgemm_gpu.tbe.bench import ( @@ -27,6 +34,15 @@ from fbgemm_gpu.tbe.utils import get_device from hypothesis import given, settings +from .. import common # noqa E402 +from ..common import running_in_oss + +try: + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
+ open_source: bool = getattr(fbgemm_gpu, "open_source", False) +except Exception: + open_source: bool = False + class TestTBEBenchmarkParamsReporter(unittest.TestCase): # pyre-ignore[56] @@ -76,21 +92,23 @@ def test_report_stats( # Generate the embedding dimension list _, Ds = tbeconfig.generate_embedding_dims() + embedding_specs = [ + ( + tbeconfig.E, + D, + embedding_location, + ( + ComputeDevice.CUDA + if torch.cuda.is_available() + else ComputeDevice.CPU + ), + ) + for D in Ds + ] + # Generate the embedding operation embedding_op = SplitTableBatchedEmbeddingBagsCodegen( - [ - ( - tbeconfig.E, - D, - embedding_location, - ( - ComputeDevice.CUDA - if torch.cuda.is_available() - else ComputeDevice.CPU - ), - ) - for D in Ds - ], + embedding_specs, embedding_table_index_type=tbeconfig.indices_params.index_dtype or torch.int64, embedding_table_offset_type=tbeconfig.indices_params.offset_dtype @@ -105,9 +123,10 @@ def test_report_stats( # Generate indices and offsets request = tbeconfig.generate_requests(1)[0] - # Call the report_stats method + # Call the extract_params method extracted_config = reporter.extract_params( - embedding_op=embedding_op, + feature_rows=embedding_op.rows_per_table, + feature_dims=embedding_op.feature_dims, indices=request.indices, offsets=request.offsets, ) @@ -125,4 +144,105 @@ def test_report_stats( and extracted_config.indices_params.offset_dtype == tbeconfig.indices_params.offset_dtype ), "Extracted config does not match the original TBEDataConfig" - # Attempt to reconstruct TBEDataConfig from extracted_json_config + + # pyre-ignore[56] + @given( + T=st.integers(1, 10), + E=st.integers(100, 10000), + D=st.sampled_from([32, 64, 128, 256]), + L=st.integers(1, 10), + B=st.integers(20, 100), + ) + @settings(max_examples=1, deadline=None) + @unittest.skipIf(*running_in_oss) + def test_report_fb_files( + self, + T: int, + E: int, + D: int, + L: int, + B: int, + ) -> None: + """ + Test writing extrcted TBEDataConfig to FB FileStore + """ + 
from fbgemm_gpu.fb.utils import FileStore + + # Initialize the reporter + bucket = "tlparse_reports" + path_prefix = "tree/unit_tests/" + + # Generate a TBEDataConfig + tbeconfig = TBEDataConfig( + T=T, + E=E, + D=D, + mixed_dim=False, + weighted=False, + batch_params=BatchParams(B=B), + indices_params=IndicesParams( + heavy_hitters=torch.tensor([]), + zipf_q=0.1, + zipf_s=0.1, + index_dtype=torch.int64, + offset_dtype=torch.int64, + ), + pooling_params=PoolingParams(L=L), + use_cpu=not torch.cuda.is_available(), + ) + + embedding_location = ( + EmbeddingLocation.DEVICE + if torch.cuda.is_available() + else EmbeddingLocation.HOST + ) + + # Generate the embedding dimension list + _, Ds = tbeconfig.generate_embedding_dims() + + with patch( + "torch.ops.fbgemm.check_feature_gate_key" + ) as mock_check_feature_gate_key: + # Mock the return value for TBE_REPORT_INPUT_PARAMS + def side_effect(feature_name: str) -> Optional[bool]: + if feature_name == FeatureGateName.TBE_REPORT_INPUT_PARAMS.name: + return True + + mock_check_feature_gate_key.side_effect = side_effect + + # Generate the embedding operation + embedding_op = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + tbeconfig.E, + D, + embedding_location, + ( + ComputeDevice.CUDA + if torch.cuda.is_available() + else ComputeDevice.CPU + ), + ) + for D in Ds + ], + ) + + embedding_op = embedding_op.to(get_device()) + + # Generate indices and offsets + request = tbeconfig.generate_requests(1)[0] + + # Execute the embedding operation with reporting flag enable + embedding_op.forward(request.indices, request.offsets) + + # Check if the file was written to Manifold + store = FileStore(bucket) + path = f"{path_prefix}tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json" + assert store.exists(path), f"{path} not exists" + + # Clenaup, delete the file + store.remove(path) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/utils/filestore_test.py 
b/fbgemm_gpu/test/utils/filestore_test.py index 38877e6785..0ee1cb930e 100644 --- a/fbgemm_gpu/test/utils/filestore_test.py +++ b/fbgemm_gpu/test/utils/filestore_test.py @@ -157,3 +157,7 @@ def test_filestore_fb_file(self) -> None: Path(infile.name), f"tree/{''.join(random.choices(string.ascii_letters, k=15))}.unittest", ) + + +if __name__ == "__main__": + unittest.main()