Commit 7f793b8

emlin authored and facebook-github-bot committed
add auto feature score collection to EC (pytorch#5030)
Summary:
X-link: meta-pytorch/torchrec#3474
X-link: facebookresearch/FBGEMM#2043

Enable automatic feature score collection in ShardedEmbeddingCollection based on a static feature-to-score mapping. If a user needs a custom score for specific ids, they can disable auto collection and instead change the model code explicitly to collect a score for each id.

Here is a sample eviction policy config in the embedding_table config that enables auto score collection:

virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
    training_id_eviction_trigger_count=260_000_000,  # 260M
    training_id_keep_count=160_000_000,  # 160M
    enable_auto_feature_score_collection=True,
    feature_score_mapping={
        "sparse_public_original_content_creator": 1.0,
    },
    feature_score_default_value=0.5,
),

Additionally, the counter previously collected during EC dedup is not used by the kvzch backend, so this diff removes that counter and lets the KJT transfer a single float32 weight tensor to the backend. This also makes feature score collection possible for EBC, which may already carry another float weight for pooling.

Reviewed By: EddyLXJ

Differential Revision: D83945722
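For illustration, a minimal sketch of how this policy could be attached to a table config. Only the FeatureScoreBasedEvictionPolicy arguments come from the summary above; the import path, the EmbeddingConfig wrapper, and the other field values are assumptions, not part of this commit.

# Sketch only: the import path, the EmbeddingConfig fields, and the table sizes
# are assumed for illustration; the eviction policy arguments mirror the summary.
from torchrec.modules.embedding_configs import (  # assumed import path
    EmbeddingConfig,
    FeatureScoreBasedEvictionPolicy,
)

table_config = EmbeddingConfig(
    name="creator_table",          # hypothetical table name
    embedding_dim=128,             # hypothetical
    num_embeddings=1_000_000_000,  # hypothetical
    feature_names=["sparse_public_original_content_creator"],
    virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
        training_id_eviction_trigger_count=260_000_000,  # 260M
        training_id_keep_count=160_000_000,              # 160M
        enable_auto_feature_score_collection=True,
        feature_score_mapping={
            "sparse_public_original_content_creator": 1.0,
        },
        feature_score_default_value=0.5,
    ),
)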
1 parent 61b22f5 commit 7f793b8

File tree

2 files changed, 2 additions(+), 4 deletions(-)

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py (1 addition, 1 deletion)

@@ -2088,7 +2088,7 @@ def _prefetch(  # noqa C901
                 torch.tensor(
                     [weights.shape[0]], device="cpu", dtype=torch.long
                 ),
-                weights.cpu().view(torch.float32).view(-1, 2),
+                weights.cpu(),
             )

         # Generate row addresses (pointing to either L1 or the current
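The removed line reinterpreted each id's weight payload as two packed float32 values (feature score plus the now-removed dedup counter); after this change the KJT weight tensor carries a single float32 score per id and is passed through unchanged. A standalone sketch of the two layouts (tensor contents are hypothetical):

import torch

# Old layout (illustrative): two float32 values per id, [feature_score, counter],
# reinterpreted before the backend call via .view(torch.float32).view(-1, 2).
old_weights = torch.tensor([0.5, 1.0, 1.0, 3.0, 0.5, 2.0], dtype=torch.float32)
score_and_counter = old_weights.view(torch.float32).view(-1, 2)  # shape (3, 2)

# New layout: the counter is dropped, so the buffer is one float32 score per id
# and is handed to the backend as-is (weights.cpu() in the diff above).
new_weights = torch.tensor([0.5, 1.0, 0.5], dtype=torch.float32)  # shape (3,)

print(score_and_counter[:, 0])  # feature scores under the old layout
print(new_weights)              # feature scores under the new layout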

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h (1 addition, 3 deletions)

@@ -768,7 +768,6 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   CHECK_EQ(indices.size(0), engege_rates.size(0));
   auto indices_data_ptr = indices.data_ptr<index_t>();
   auto engage_rate_ptr = engege_rates.data_ptr<float>();
-  int64_t stride = 2;
   {
     auto before_write_lock_ts =
         facebook::WallClockUtil::NowInUsecFast();
@@ -783,8 +782,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
          index_iter++) {
       const auto& id_index = *index_iter;
       auto id = int64_t(indices_data_ptr[id_index]);
-      float engege_rate =
-          float(engage_rate_ptr[id_index * stride + 0]);
+      float engege_rate = float(engage_rate_ptr[id_index]);
       // use mempool
       weight_type* block = nullptr;
       auto before_lookup_cache_ts =
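On the DRAM KV cache side, the same layout change removes the stride-2 read: the engage rate for the i-th id now sits directly at index i of a flat float32 tensor. A small Python analogue of the indexing change (tensor contents are hypothetical):

import torch

engage_rates = torch.tensor([0.5, 1.0, 0.5], dtype=torch.float32)  # one score per id

# Before: scores were interleaved with counters, so id i was read at offset 2 * i.
# After: the tensor is flat, so id i is read directly at index i.
for id_index in range(engage_rates.numel()):
    engage_rate = float(engage_rates[id_index])
    # ... look up or insert the cache block for this id using engage_rate ...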
