Skip to content

Use bounds_check_indices v2 on ROCm #4085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,8 @@ def bounds_check_indices_abstract(
b_t_map: Optional[torch.Tensor] = None,
info_B_num_bits: int = -1,
info_B_mask: int = -1,
bounds_check_version: int = 1,
prefetch_pipeline: bool = False,
) -> None:
"""
This meta function is used to fake the bounds checking
Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,13 @@ def get_new_embedding_location(
# UVM caching
else:
return EmbeddingLocation.MANAGED_CACHING


def get_bounds_check_version_for_platform() -> int:
    """
    Pick the default bounds_check_indices kernel version for this platform.

    ROCm requires gridDim * blockDim to be smaller than 2^32. The v1
    kernel can be launched with gridDim * blockDim > 2^32, whereas the v2
    kernel limits the gridDim size to 64 * the number of SMs, which keeps
    gridDim * blockDim below 2^32. v2 is therefore selected on ROCm.

    Returns:
        2 when a GPU is available and torch reports a HIP (ROCm) build,
        otherwise 1.
    """
    # `and` short-circuits, so torch.version.hip is only read when a GPU
    # is actually available.
    running_on_rocm = torch.cuda.is_available() and torch.version.hip
    if running_on_rocm:
        return 2
    return 1
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
EmbeddingLocation,
EmbeddingSpecInfo,
get_bounds_check_version_for_platform,
get_new_embedding_location,
MAX_PREFETCH_DEPTH,
PoolingMode,
Expand Down Expand Up @@ -635,6 +636,8 @@ def max_ty_D(ty: SparseType) -> int:
self.fp8_exponent_bits = -1
self.fp8_exponent_bias = -1

self.bounds_check_version: int = get_bounds_check_version_for_platform()

@torch.jit.ignore
def log(self, msg: str) -> None:
"""
Expand Down Expand Up @@ -975,6 +978,7 @@ def _forward_impl(
self.bounds_check_mode_int,
self.bounds_check_warning,
per_sample_weights,
bounds_check_version=self.bounds_check_version,
)

# Index remapping changes input indices, and some of them become -1 (pruned rows).
Expand Down Expand Up @@ -1017,6 +1021,7 @@ def _forward_impl(
self.bounds_check_mode_int,
self.bounds_check_warning,
per_sample_weights,
bounds_check_version=self.bounds_check_version,
)
# Note: CPU and CUDA ops use the same interface to facilitate JIT IR
# generation for CUDA/CPU. For CPU op, we don't need weights_uvm and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ComputeDevice,
construct_cache_state,
EmbeddingLocation,
get_bounds_check_version_for_platform,
MAX_PREFETCH_DEPTH,
MultiPassPrefetchConfig,
PoolingMode,
Expand Down Expand Up @@ -692,7 +693,7 @@ def __init__( # noqa C901
self.bounds_check_version: int = (
2
if self._feature_is_enabled(FeatureGateName.BOUNDS_CHECK_INDICES_V2)
else 1
else get_bounds_check_version_for_platform()
)
self.bounds_check_mode_int: int = int(
os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
Expand Down Expand Up @@ -735,6 +736,7 @@ def __init__( # noqa C901
# See:
# https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
cache_precision = SparseType.FP32
self.log("Override cache_precision=SparseType.FP32 on ROCm")
else:
# NOTE: The changes from D65865527 are retained here until we can
# test that the hack also works for non-ROCm environments.
Expand Down
4 changes: 4 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
BoundsCheckMode,
CacheAlgorithm,
EmbeddingLocation,
get_bounds_check_version_for_platform,
KVZCHParams,
PoolingMode,
SplitState,
Expand Down Expand Up @@ -737,6 +738,8 @@ def __init__(
self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name)
self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name)

self.bounds_check_version: int = get_bounds_check_version_for_platform()

@property
# pyre-ignore
def ssd_db(self):
Expand Down Expand Up @@ -2013,6 +2016,7 @@ def prepare_inputs(
per_sample_weights,
B_offsets=vbe_metadata.B_offsets,
max_B=vbe_metadata.max_B,
bounds_check_version=self.bounds_check_version,
)

return indices, offsets, per_sample_weights, vbe_metadata
Expand Down
Loading