@@ -11,7 +11,7 @@

 import copy

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

 import click
@@ -26,6 +26,7 @@
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.planner.types import ParameterConstraints

@@ -80,6 +81,9 @@ class RunOptions:
         planner_type (str): Type of sharding planner to use. Options are:
             - "embedding": EmbeddingShardingPlanner (default)
             - "hetero": HeteroEmbeddingShardingPlanner
+        pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
+            This is the average number of values each sample has for the feature.
+        num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
     """

     world_size: int = 2
@@ -89,6 +93,8 @@ class RunOptions:
     input_type: str = "kjt"
     profile: str = ""
     planner_type: str = "embedding"
+    pooling_factors: Optional[List[float]] = None
+    num_poolings: Optional[List[float]] = None


 @dataclass
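With these two fields in place, a caller can override the planner's pooling statistics per run. A minimal sketch of how they might be set; the values are illustrative, and it relies on the num_batches option that runner() reads later in this diff:

    # Illustrative values only; not defaults introduced by this diff.
    run_option = RunOptions(
        world_size=2,
        planner_type="embedding",
        pooling_factors=[15.0] * 10,  # one entry per batch (see the assert below)
        num_poolings=[1.0] * 10,      # likewise, length must equal num_batches
    )

Leaving both fields at None keeps the old behavior: _generate_planner (below) falls back to the planner's default constants.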
@@ -111,7 +117,7 @@ class EmbeddingTablesConfig:

     num_unweighted_features: int = 100
     num_weighted_features: int = 100
-    embedding_feature_dim: int = 512
+    embedding_feature_dim: int = 128

     def generate_tables(
         self,
@@ -286,17 +292,36 @@ def _generate_planner(
     tables: Optional[List[EmbeddingBagConfig]],
     weighted_tables: Optional[List[EmbeddingBagConfig]],
     sharding_type: ShardingType,
-    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED,
+    compute_kernel: EmbeddingComputeKernel,
+    num_batches: int,
+    batch_size: int,
+    pooling_factors: Optional[List[float]],
+    num_poolings: Optional[List[float]],
 ) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
     # Create parameter constraints for tables
     constraints = {}

+    if pooling_factors is None:
+        pooling_factors = [POOLING_FACTOR] * num_batches
+
+    if num_poolings is None:
+        num_poolings = [NUM_POOLINGS] * num_batches
+
+    batch_sizes = [batch_size] * num_batches
+
+    assert (
+        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
+    ), "The length of pooling_factors and num_poolings must match the number of batches."
+
     if tables is not None:
         for table in tables:
             constraints[table.name] = ParameterConstraints(
                 sharding_types=[sharding_type.value],
                 compute_kernels=[compute_kernel.value],
                 device_group="cuda",
+                pooling_factors=pooling_factors,
+                num_poolings=num_poolings,
+                batch_sizes=batch_sizes,
             )

     if weighted_tables is not None:
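The None-handling above substitutes the planner's global constants (POOLING_FACTOR and NUM_POOLINGS from torchrec.distributed.planner.constants) and then asserts that any caller-supplied lists cover every batch. The same default-and-validate pattern, factored into a standalone helper for illustration; _fill_per_batch is hypothetical, not part of this diff:

    from typing import List, Optional

    def _fill_per_batch(
        values: Optional[List[float]], default: float, num_batches: int
    ) -> List[float]:
        # None -> broadcast the default; otherwise require one entry per batch.
        if values is None:
            return [default] * num_batches
        assert len(values) == num_batches, (
            f"expected {num_batches} entries, got {len(values)}"
        )
        return values

Note that the assert in the diff also runs when both lists were just built from the defaults, in which case it is trivially true.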
@@ -305,6 +330,10 @@ def _generate_planner(
                 sharding_types=[sharding_type.value],
                 compute_kernels=[compute_kernel.value],
                 device_group="cuda",
+                pooling_factors=pooling_factors,
+                num_poolings=num_poolings,
+                batch_sizes=batch_sizes,
+                is_weighted=True,
             )

     if planner_type == "embedding":
@@ -413,6 +442,10 @@ def runner(
         weighted_tables=weighted_tables,
         sharding_type=run_option.sharding_type,
         compute_kernel=run_option.compute_kernel,
+        num_batches=run_option.num_batches,
+        batch_size=input_config.batch_size,
+        pooling_factors=run_option.pooling_factors,
+        num_poolings=run_option.num_poolings,
     )

     sharded_model, optimizer = _generate_sharded_model_and_optimizer(
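End to end, runner() now threads the batch shape and pooling statistics from the run options into the planner. A hypothetical call mirroring this wiring; the parameters before tables are cut off above line 286 of the file, so planner_type and topology here are assumptions, and all values are illustrative:

    # Sketch of the new _generate_planner signature in use; values are made up.
    planner = _generate_planner(
        planner_type="embedding",
        topology=topology,
        tables=tables,
        weighted_tables=weighted_tables,
        sharding_type=ShardingType.TABLE_WISE,
        compute_kernel=EmbeddingComputeKernel.FUSED,
        num_batches=10,
        batch_size=8192,
        pooling_factors=None,  # -> [POOLING_FACTOR] * 10
        num_poolings=None,     # -> [NUM_POOLINGS] * 10
    )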