Skip to content

Commit 389c70a

Browse files
EddyLXJ authored and facebook-github-bot committed
Support no eviction in Feature score eviction policy (#5059)
Summary: X-link: meta-pytorch/torchrec#3488 X-link: facebookresearch/FBGEMM#2068 As titled. If one table uses feature score eviction in a TBE, then all tables in that TBE need to use the same policy. Feature score eviction already supports TTL-based eviction. This diff adds support for disabling eviction (no eviction) for individual tables under the feature score eviction policy. Differential Revision: D84660528
1 parent 89f2bf4 commit 389c70a

File tree

7 files changed

+212
-6
lines changed

7 files changed

+212
-6
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ class EvictionPolicy(NamedTuple):
120120
eviction_free_mem_check_interval_batch: Optional[int] = (
121121
None # Number of batches between checks for free memory threshold when using free_mem trigger mode.
122122
)
123+
enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = (
124+
None # enable eviction if eviction policy is feature score, false means no eviction
125+
)
123126

124127
def validate(self) -> None:
125128
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
@@ -217,13 +220,17 @@ def validate(self) -> None:
217220
"threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
218221
f"actual {self.threshold_calculation_bucket_num}"
219222
)
223+
assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
224+
"enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5,"
225+
f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
226+
)
220227
assert (
221-
len(self.training_id_keep_count)
228+
len(self.enable_eviction_for_feature_score_eviction_policy)
229+
== len(self.training_id_keep_count)
222230
== len(self.feature_score_counter_decay_rates)
223-
== len(self.training_id_eviction_trigger_count)
224231
), (
225-
"feature_score_thresholds, training_id_eviction_trigger_count and training_id_keep_count must have the same length, "
226-
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.training_id_eviction_trigger_count}"
232+
"feature_score_thresholds, enable_eviction_for_feature_score_eviction_policy, and training_id_keep_count must have the same length, "
233+
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.enable_eviction_for_feature_score_eviction_policy}"
227234
)
228235

229236

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,15 @@ def __init__(
707707
# If trigger mode is free_mem(5), populate config
708708
self.set_free_mem_eviction_trigger_config(eviction_policy)
709709

710+
enable_eviction_for_feature_score_eviction_policy = ( # pytorch api in c++ doesn't support vector<bool>, convert to int here, 0: no eviction 1: eviction
711+
[
712+
int(x)
713+
for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
714+
]
715+
if eviction_policy.enable_eviction_for_feature_score_eviction_policy
716+
is not None
717+
else None
718+
)
710719
# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
711720
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
712721
eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
@@ -719,6 +728,7 @@ def __init__(
719728
eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
720729
eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
721730
eviction_policy.training_id_keep_count, # training_id_keep_count for each table
731+
enable_eviction_for_feature_score_eviction_policy, # no eviction setting for feature score eviction policy
722732
eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
723733
table_dims.tolist() if table_dims is not None else None,
724734
eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ void DramKVEmbeddingInferenceWrapper::init(
7070
std::nullopt /* feature_score_counter_decay_rates */,
7171
std::nullopt /* training_id_eviction_trigger_count */,
7272
std::nullopt /* training_id_keep_count */,
73+
std::nullopt /* enable_eviction_for_feature_score_eviction_policy */,
7374
std::nullopt /* l2_weight_thresholds */,
7475
std::nullopt /* embedding_dims */,
7576
std::nullopt /* threshold_calculation_bucket_stride */,

fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
110110
std::optional<std::vector<double>> feature_score_counter_decay_rates,
111111
std::optional<std::vector<int64_t>> training_id_eviction_trigger_count,
112112
std::optional<std::vector<int64_t>> training_id_keep_count,
113+
std::optional<std::vector<int8_t>>
114+
enable_eviction_for_feature_score_eviction_policy, // 0: no eviction,
115+
// 1: evict
113116
std::optional<std::vector<double>> l2_weight_thresholds,
114117
std::optional<std::vector<int64_t>> embedding_dims,
115118
std::optional<double> threshold_calculation_bucket_stride = 0.2,
@@ -129,6 +132,8 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
129132
training_id_eviction_trigger_count_(
130133
std::move(training_id_eviction_trigger_count)),
131134
training_id_keep_count_(std::move(training_id_keep_count)),
135+
enable_eviction_for_feature_score_eviction_policy_(
136+
std::move(enable_eviction_for_feature_score_eviction_policy)),
132137
l2_weight_thresholds_(l2_weight_thresholds),
133138
embedding_dims_(embedding_dims),
134139
threshold_calculation_bucket_stride_(
@@ -212,6 +217,7 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
212217
CHECK(threshold_calculation_bucket_stride_.has_value());
213218
CHECK(threshold_calculation_bucket_num_.has_value());
214219
CHECK(ttls_in_mins_.has_value());
220+
CHECK(enable_eviction_for_feature_score_eviction_policy_.has_value());
215221
LOG(INFO) << "eviction config, trigger mode:"
216222
<< to_string(trigger_mode_) << eviction_trigger_stats_log
217223
<< ", strategy: " << to_string(trigger_strategy_)
@@ -223,7 +229,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
223229
<< ", threshold_calculation_bucket_num: "
224230
<< threshold_calculation_bucket_num_.value()
225231
<< ", feature_score_counter_decay_rates: "
226-
<< feature_score_counter_decay_rates_.value();
232+
<< feature_score_counter_decay_rates_.value()
233+
<< ", enable_eviction_for_feature_score_eviction_policy: "
234+
<< enable_eviction_for_feature_score_eviction_policy_.value();
227235
return;
228236
}
229237

@@ -281,6 +289,8 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
281289
std::optional<std::vector<double>> feature_score_counter_decay_rates_;
282290
std::optional<std::vector<int64_t>> training_id_eviction_trigger_count_;
283291
std::optional<std::vector<int64_t>> training_id_keep_count_;
292+
std::optional<std::vector<int8_t>>
293+
enable_eviction_for_feature_score_eviction_policy_;
284294
std::optional<int64_t> total_id_eviction_trigger_count_;
285295
std::optional<std::vector<double>> l2_weight_thresholds_;
286296
std::optional<std::vector<int64_t>> embedding_dims_;
@@ -984,6 +994,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
984994
const std::vector<int64_t>& training_id_eviction_trigger_count,
985995
const std::vector<int64_t>& training_id_keep_count,
986996
const std::vector<int64_t>& ttls_in_mins,
997+
const std::vector<int8_t>&
998+
enable_eviction_for_feature_score_eviction_policy,
987999
const double threshold_calculation_bucket_stride,
9881000
const int64_t threshold_calculation_bucket_num,
9891001
int64_t interval_for_insufficient_eviction_s,
@@ -1003,6 +1015,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
10031015
training_id_eviction_trigger_count_(training_id_eviction_trigger_count),
10041016
training_id_keep_count_(training_id_keep_count),
10051017
ttls_in_mins_(ttls_in_mins),
1018+
enable_eviction_for_feature_score_eviction_policy_(
1019+
enable_eviction_for_feature_score_eviction_policy),
10061020
threshold_calculation_bucket_stride_(
10071021
threshold_calculation_bucket_stride),
10081022
num_buckets_(threshold_calculation_bucket_num),
@@ -1071,6 +1085,13 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
10711085
protected:
10721086
bool evict_block(weight_type* block, int sub_table_id, int shard_id)
10731087
override {
1088+
int8_t enable_eviction =
1089+
enable_eviction_for_feature_score_eviction_policy_[sub_table_id];
1090+
if (enable_eviction == 0) {
1091+
// If enable_eviction is set to 0, we don't evict any block.
1092+
return false;
1093+
}
1094+
10741095
double ttls_threshold = ttls_in_mins_[sub_table_id];
10751096
if (ttls_threshold > 0) {
10761097
auto current_time = FixedBlockPool::current_timestamp();
@@ -1145,6 +1166,15 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
11451166

11461167
void compute_thresholds_from_buckets() {
11471168
for (size_t table_id = 0; table_id < num_tables_; ++table_id) {
1169+
int8_t enable_eviction =
1170+
enable_eviction_for_feature_score_eviction_policy_[table_id];
1171+
if (enable_eviction == 0) {
1172+
// If enable_eviction is set to 0, we don't evict any block.
1173+
thresholds_[table_id] = 0.0;
1174+
evict_modes_[table_id] = EvictMode::NONE;
1175+
continue;
1176+
}
1177+
11481178
int64_t total = 0;
11491179

11501180
if (ttls_in_mins_[table_id] > 0) {
@@ -1209,7 +1239,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
12091239
<< " threshold bucket: " << threshold_bucket
12101240
<< " actual evict count: " << acc_count
12111241
<< " target evict count: " << evict_count
1212-
<< " total count: " << total;
1242+
<< " total count: " << total
1243+
<< " evict mode: " << to_string(evict_modes_[table_id]);
12131244

12141245
for (int table_id = 0; table_id < num_tables_; ++table_id) {
12151246
this->metrics_.eviction_threshold_with_dry_run[table_id] =
@@ -1226,6 +1257,16 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
12261257
THRESHOLD // blocks with scores below the computed threshold will be
12271258
// evicted
12281259
};
1260+
/// Returns the name of an EvictMode value as a string, for use in log output.
///
/// @param mode Eviction mode to describe.
/// @return "NONE", "ONLY_ZERO" or "THRESHOLD" for the matching enumerator;
///         "UNKNOWN" for any value not handled by the switch below.
inline std::string to_string(EvictMode mode) {
  // No `default:` label on purpose, so the compiler can still warn (-Wswitch)
  // if an enumerator is added to EvictMode without being handled here.
  switch (mode) {
    case EvictMode::NONE:
      return "NONE";
    case EvictMode::ONLY_ZERO:
      return "ONLY_ZERO";
    case EvictMode::THRESHOLD:
      return "THRESHOLD";
  }
  // Reached only when `mode` holds a value outside the cases above. Flowing
  // off the end of a non-void function is undefined behavior, so return a
  // sentinel string instead.
  return "UNKNOWN";
}
12291270
std::vector<EvictMode> evict_modes_;
12301271

12311272
const int num_tables_ = static_cast<int>(this->sub_table_hash_cumsum_.size());
@@ -1240,6 +1281,7 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
12401281
// eviction.
12411282

12421283
const std::vector<int64_t>& ttls_in_mins_; // Time-to-live for eviction.
1284+
const std::vector<int8_t>& enable_eviction_for_feature_score_eviction_policy_;
12431285
std::vector<std::vector<std::vector<size_t>>>
12441286
local_buckets_per_shard_per_table_;
12451287
std::vector<std::vector<size_t>> local_blocks_num_per_shard_per_table_;
@@ -1489,6 +1531,7 @@ std::unique_ptr<FeatureEvict<weight_type>> create_feature_evict(
14891531
config->training_id_eviction_trigger_count_.value(),
14901532
config->training_id_keep_count_.value(),
14911533
config->ttls_in_mins_.value(),
1534+
config->enable_eviction_for_feature_score_eviction_policy_.value(),
14921535
config->threshold_calculation_bucket_stride_.value(),
14931536
config->threshold_calculation_bucket_num_.value(),
14941537
config->interval_for_insufficient_eviction_s_,

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,7 @@ static auto feature_evict_config =
737737
std::optional<std::vector<double>>,
738738
std::optional<std::vector<int64_t>>,
739739
std::optional<std::vector<int64_t>>,
740+
std::optional<std::vector<int8_t>>,
740741
std::optional<std::vector<double>>,
741742
std::optional<std::vector<int64_t>>,
742743
std::optional<double>,
@@ -756,6 +757,9 @@ static auto feature_evict_config =
756757
torch::arg("feature_score_counter_decay_rates") = std::nullopt,
757758
torch::arg("training_id_eviction_trigger_count") = std::nullopt,
758759
torch::arg("training_id_keep_count") = std::nullopt,
760+
torch::arg(
761+
"enable_eviction_for_feature_score_eviction_policy") =
762+
std::nullopt,
759763
torch::arg("l2_weight_thresholds") = std::nullopt,
760764
torch::arg("embedding_dims") = std::nullopt,
761765
torch::arg("threshold_calculation_bucket_stride") = 0.2,

0 commit comments

Comments
 (0)