
Commit f7c74e4

emlinmeta-codesync[bot] authored and committed
enable feature score auto collection in EBC (#3475)
Summary:
X-link: pytorch/FBGEMM#5031
Pull Request resolved: #3475
X-link: https://github.com/facebookresearch/FBGEMM/pull/2044

Enable feature score auto collection for EBC in the same way as for EC. The embedding table config is no different from EC's:

    virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
        training_id_eviction_trigger_count=260_000_000,  # 260M
        training_id_keep_count=160_000_000,  # 160M
        enable_auto_feature_score_collection=True,
        feature_score_mapping={
            "sparse_public_original_content_creator": 1.0,
        },
        feature_score_default_value=0.5,
    ),

Reviewed By: EddyLXJ

Differential Revision: D85017179

fbshipit-source-id: 3d62f8adbe201d6e30c445aaed88710bbbcd6557
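For context, a hypothetical sketch (not part of this commit) of how the policy above could be attached to an EBC table config. Only the eviction-policy kwargs are taken from the summary; the table name, dims, feature list, and the exact EmbeddingBagConfig fields are assumptions.

    # Hypothetical EBC table config; only the eviction-policy kwargs come from the summary above.
    from torchrec.modules.embedding_configs import (
        EmbeddingBagConfig,
        FeatureScoreBasedEvictionPolicy,
    )

    example_table = EmbeddingBagConfig(
        name="content_creator_table",  # made-up name
        embedding_dim=128,  # made-up dim
        num_embeddings=1_000_000,  # made-up size
        feature_names=["sparse_public_original_content_creator"],
        use_virtual_table=True,  # assumed field; mirrors the table.use_virtual_table check in embedding_lookup.py below
        virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
            training_id_eviction_trigger_count=260_000_000,  # 260M
            training_id_keep_count=160_000_000,  # 160M
            enable_auto_feature_score_collection=True,
            feature_score_mapping={
                "sparse_public_original_content_creator": 1.0,
            },
            feature_score_default_value=0.5,
        ),
    )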
1 parent 7c7daaf · commit f7c74e4


6 files changed: +339 -8 lines changed


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 39 additions & 1 deletion
@@ -326,6 +326,8 @@ def _populate_zero_collision_tbe_params(
             meta_header_lens[i] = table.virtual_table_eviction_policy.get_meta_header_len()
             if not isinstance(table.virtual_table_eviction_policy, NoEvictionPolicy):
                 enabled = True
+
+    fs_eviction_enabled: bool = False
     if enabled:
         counter_thresholds = [0] * len(config.embedding_tables)
         ttls_in_mins = [0] * len(config.embedding_tables)

@@ -384,6 +386,7 @@ def _populate_zero_collision_tbe_params(
                 raise ValueError(
                     f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 5 for tables {table_names}"
                 )
+            fs_eviction_enabled = True
         elif isinstance(policy_t, TimestampBasedEvictionPolicy):
             training_id_eviction_trigger_count[i] = (
                 policy_t.training_id_eviction_trigger_count

@@ -459,6 +462,7 @@ def _populate_zero_collision_tbe_params(
         backend_return_whole_row=(backend_type == BackendType.DRAM),
         eviction_policy=eviction_policy,
         embedding_cache_mode=embedding_cache_mode_,
+        feature_score_collection_enabled=fs_eviction_enabled,
     )


@@ -2901,6 +2905,7 @@ def __init__(
         _populate_zero_collision_tbe_params(
             ssd_tbe_params, self._bucket_spec, config, backend_type
         )
+        self._kv_zch_params: KVZCHParams = ssd_tbe_params["kv_zch_params"]
         compute_kernel = config.embedding_tables[0].compute_kernel
         embedding_location = compute_kernel_to_embedding_location(compute_kernel)

@@ -3185,7 +3190,40 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             self._split_weights_res = None
             self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)

-        return super().forward(features)
+        weights = features.weights_or_none()
+        per_sample_weights = None
+        score_weights = None
+        if weights is not None and weights.dtype == torch.float64:
+            fp32_weights = weights.view(torch.float32)
+            per_sample_weights = fp32_weights[:, 0]
+            score_weights = fp32_weights[:, 1]
+        elif weights is not None and weights.dtype == torch.float32:
+            if self._kv_zch_params.feature_score_collection_enabled:
+                score_weights = weights.view(-1)
+            else:
+                per_sample_weights = weights.view(-1)
+        if features.variable_stride_per_key() and isinstance(
+            self.emb_module,
+            (
+                SplitTableBatchedEmbeddingBagsCodegen,
+                DenseTableBatchedEmbeddingBagsCodegen,
+                SSDTableBatchedEmbeddingBags,
+            ),
+        ):
+            return self.emb_module(
+                indices=features.values().long(),
+                offsets=features.offsets().long(),
+                weights=score_weights,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+            )
+        else:
+            return self.emb_module(
+                indices=features.values().long(),
+                offsets=features.offsets().long(),
+                weights=score_weights,
+                per_sample_weights=per_sample_weights,
+            )


 class BatchedFusedEmbeddingBag(
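The rewritten forward above unpacks a single float64 weights column into two float32 columns (per-sample weight and feature score); the lookup code in embedding_lookup.py below packs them with torch.cat(...).view(torch.float64). A minimal standalone sketch of that bit-level round trip, using plain PyTorch with illustrative names (not the TorchRec code path):

    import torch

    n = 4
    per_sample_weights = torch.rand(n, 1, dtype=torch.float32)
    feature_scores = torch.full((n, 1), 0.5, dtype=torch.float32)

    # Pack: [n, 2] float32 -> [n, 1] float64 (bitwise reinterpretation, no rounding).
    packed = torch.cat([per_sample_weights, feature_scores], dim=1).view(torch.float64)

    # Unpack (kernel side): [n, 1] float64 -> [n, 2] float32, then split the columns,
    # mirroring fp32_weights[:, 0] / fp32_weights[:, 1] in the forward above.
    unpacked = packed.view(torch.float32)
    assert torch.equal(unpacked[:, 0], per_sample_weights.view(-1))
    assert torch.equal(unpacked[:, 1], feature_scores.view(-1))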

torchrec/distributed/embedding_lookup.py

Lines changed: 35 additions & 4 deletions
@@ -66,6 +66,7 @@
     QuantBatchedEmbeddingBag,
 )
 from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType
+from torchrec.modules.embedding_configs import FeatureScoreBasedEvictionPolicy
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

 logger: logging.Logger = logging.getLogger(__name__)

@@ -515,6 +516,23 @@ def __init__(
     ) -> None:
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
+        self._feature_score_auto_collections: List[bool] = []
+        for config in grouped_configs:
+            collection = False
+            for table in config.embedding_tables:
+                if table.use_virtual_table and isinstance(
+                    table.virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy
+                ):
+                    if (
+                        table.virtual_table_eviction_policy.enable_auto_feature_score_collection
+                    ):
+                        collection = True
+            self._feature_score_auto_collections.append(collection)
+
+        logger.info(
+            f"GroupedPooledEmbeddingsLookup: {self._feature_score_auto_collections=}"
+        )
+
         for config in grouped_configs:
             self._emb_modules.append(
                 self._create_embedding_kernel(config, device, pg, sharding_type)

@@ -692,8 +710,11 @@ def forward(
         features_by_group = sparse_features.split(
             self._feature_splits,
         )
-        for config, emb_op, features in zip(
-            self.grouped_configs, self._emb_modules, features_by_group
+        for config, emb_op, features, fs_auto_collection in zip(
+            self.grouped_configs,
+            self._emb_modules,
+            features_by_group,
+            self._feature_score_auto_collections,
         ):
             if (
                 config.has_feature_processor

@@ -703,9 +724,19 @@
                 features = self._feature_processor(features)

             if config.is_weighted:
-                features._weights = CommOpGradientScaling.apply(
+                feature_weights = CommOpGradientScaling.apply(
                     features._weights, self._scale_gradient_factor
-                )
+                ).float()
+
+                if fs_auto_collection and features.weights_or_none() is not None:
+                    score_weights = features.weights().float()
+                    assert (
+                        feature_weights.numel() == score_weights.numel()
+                    ), f"feature_weights.numel() {feature_weights.numel()} != score_weights.numel() {score_weights.numel()}"
+                    cat_weights = torch.cat(
+                        [feature_weights, score_weights], dim=1
+                    ).view(torch.float64)
+                    features._weights = cat_weights

             lookup = emb_op(features)
             embeddings.append(lookup)

torchrec/distributed/embeddingbag.py

Lines changed: 27 additions & 0 deletions
@@ -51,6 +51,10 @@
     KJTList,
     ShardedEmbeddingModule,
 )
+from torchrec.distributed.feature_score_utils import (
+    create_sharding_type_to_feature_score_mapping,
+    may_collect_feature_scores,
+)
 from torchrec.distributed.fused_params import (
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,

@@ -565,6 +569,24 @@ def __init__(
         # forward pass flow control
         self._has_uninitialized_input_dist: bool = True
         self._has_features_permute: bool = True
+
+        self._enable_feature_score_weight_accumulation: bool = False
+        self._enabled_feature_score_auto_collection: bool = False
+        self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
+        (
+            self._enable_feature_score_weight_accumulation,
+            self._enabled_feature_score_auto_collection,
+            self._sharding_type_feature_score_mapping,
+        ) = create_sharding_type_to_feature_score_mapping(
+            self._embedding_bag_configs, self.sharding_type_to_sharding_infos
+        )
+
+        logger.info(
+            f"EBC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, "
+            f"auto collection enabled: {self._enabled_feature_score_auto_collection}, "
+            f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}"
+        )
+
         # Get all fused optimizers and combine them.
         optims = []
         for lookup in self._lookups:

@@ -1565,6 +1587,11 @@ def input_dist(
         features_by_shards = features.split(
             self._feature_splits,
         )
+        features_by_shards = may_collect_feature_scores(
+            features_by_shards,
+            self._enabled_feature_score_auto_collection,
+            self._sharding_type_feature_score_mapping,
+        )
         awaitables = []
         for input_dist, features_by_shard, sharding_type in zip(
             self._input_dists,
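The real may_collect_feature_scores helper lives in torchrec/distributed/feature_score_utils.py and is not part of this diff. Purely as a rough illustration of the idea from the commit summary (score each id from feature_score_mapping, falling back to feature_score_default_value, and carry the scores as KJT weights), a hypothetical, simplified stand-in might look like the following; the function name and the flattened mapping are made up, and the real helper works per sharding type and preserves existing per-sample weights:

    from typing import Dict, List

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    def attach_feature_scores(
        kjt: KeyedJaggedTensor,
        feature_score_mapping: Dict[str, float],
        default_value: float = 0.5,
    ) -> KeyedJaggedTensor:
        # Build one score per id, keyed by feature name, in the KJT's key-major value order.
        scores: List[torch.Tensor] = []
        for key, jt in kjt.to_dict().items():
            score = feature_score_mapping.get(key, default_value)
            scores.append(torch.full_like(jt.values(), score, dtype=torch.float32))
        # Re-create the KJT with the scores attached as weights so they flow to the lookup.
        return KeyedJaggedTensor(
            keys=kjt.keys(),
            values=kjt.values(),
            lengths=kjt.lengths(),
            weights=torch.cat(scores) if scores else None,
        )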

torchrec/distributed/feature_score_utils.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from torchrec.distributed.embedding_types import ShardingType

 from torchrec.modules.embedding_configs import (
-    EmbeddingConfig,
+    BaseEmbeddingConfig,
     FeatureScoreBasedEvictionPolicy,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -26,7 +26,7 @@


 def create_sharding_type_to_feature_score_mapping(
-    embedding_configs: Sequence[EmbeddingConfig],
+    embedding_configs: Sequence[BaseEmbeddingConfig],
     sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]],
 ) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]:
     enable_feature_score_weight_accumulation = False
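(Widening the annotation from EmbeddingConfig to BaseEmbeddingConfig lets the same helper accept EBC table configs as well as EC ones: EmbeddingBagConfig and EmbeddingConfig both derive from BaseEmbeddingConfig.)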
