diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 97b357772d..bfce26b569 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -832,6 +832,8 @@ def bounds_check_indices_abstract( b_t_map: Optional[torch.Tensor] = None, info_B_num_bits: int = -1, info_B_mask: int = -1, + bounds_check_version: int = 1, + prefetch_pipeline: bool = False, ) -> None: """ This meta function is used to fake the bounds checking diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index b301be3ae4..eb18d0fc09 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -254,3 +254,13 @@ def get_new_embedding_location( # UVM caching else: return EmbeddingLocation.MANAGED_CACHING + + +def get_bounds_check_version_for_platform() -> int: + # NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a + # constraint that the gridDim * blockDim has to be smaller than + # 2^32. The v1 kernel can be launched with gridDim * blockDim > + # 2^32 while the v2 kernel limits the gridDim size to 64 * # of + # SMs. Thus, its gridDim * blockDim is guaranteed to be smaller + # than 2^32 + return 2 if (torch.cuda.is_available() and torch.version.hip) else 1 diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 20f93c1b1b..f0ce9438d6 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -28,6 +28,7 @@ DEFAULT_SCALE_BIAS_SIZE_IN_BYTES, EmbeddingLocation, EmbeddingSpecInfo, + get_bounds_check_version_for_platform, get_new_embedding_location, MAX_PREFETCH_DEPTH, PoolingMode, @@ -635,6 +636,8 @@ def max_ty_D(ty: SparseType) -> int: self.fp8_exponent_bits = -1 self.fp8_exponent_bias = -1 + self.bounds_check_version: int = get_bounds_check_version_for_platform() + @torch.jit.ignore def log(self, msg: str) -> None: """ @@ -975,6 +978,7 @@ def _forward_impl( self.bounds_check_mode_int, self.bounds_check_warning, per_sample_weights, + bounds_check_version=self.bounds_check_version, ) # Index remapping changes input indices, and some of them becomes -1 (prunned rows). @@ -1017,6 +1021,7 @@ def _forward_impl( self.bounds_check_mode_int, self.bounds_check_warning, per_sample_weights, + bounds_check_version=self.bounds_check_version, ) # Note: CPU and CUDA ops use the same interface to facilitate JIT IR # generation for CUDA/CPU. For CPU op, we don't need weights_uvm and diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index f361088852..10dd69f1e7 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -40,6 +40,7 @@ ComputeDevice, construct_cache_state, EmbeddingLocation, + get_bounds_check_version_for_platform, MAX_PREFETCH_DEPTH, MultiPassPrefetchConfig, PoolingMode, @@ -692,7 +693,7 @@ def __init__( # noqa C901 self.bounds_check_version: int = ( 2 if self._feature_is_enabled(FeatureGateName.BOUNDS_CHECK_INDICES_V2) - else 1 + else get_bounds_check_version_for_platform() ) self.bounds_check_mode_int: int = int( os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value) @@ -735,6 +736,7 @@ def __init__( # noqa C901 # See: # https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/ cache_precision = SparseType.FP32 + self.log("Override cache_precision=SparseType.FP32 on ROCm") else: # NOTE: The changes from D65865527 are retained here until we can # test that the the hack also works for non-ROCm environments. diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 64e2ef4569..9e3c192391 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -33,6 +33,7 @@ BoundsCheckMode, CacheAlgorithm, EmbeddingLocation, + get_bounds_check_version_for_platform, KVZCHParams, PoolingMode, SplitState, @@ -737,6 +738,8 @@ def __init__( self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name) self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name) + self.bounds_check_version: int = get_bounds_check_version_for_platform() + @property # pyre-ignore def ssd_db(self): @@ -2013,6 +2016,7 @@ def prepare_inputs( per_sample_weights, B_offsets=vbe_metadata.B_offsets, max_B=vbe_metadata.max_B, + bounds_check_version=self.bounds_check_version, ) return indices, offsets, per_sample_weights, vbe_metadata