Skip to content

Commit 704be8d

Browse files
Hanna Xu authored and facebook-github-bot committed
Support DI sharding for FPE_EBC
Summary: Support models that have FeatureProcessedEmbeddingBagCollection. These changes make sure we add QuantFeatureProcessedEmbeddingBagCollectionSharder as a recognized sharder, handle multiple envs needed for specifying DI sharding, and propagate TBE properly when processing the sharding plan. Differential Revision: D74671655
1 parent d420392 commit 704be8d

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

torchrec/distributed/quant_embeddingbag.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def __init__(
430430
self,
431431
module: EmbeddingBagCollectionInterface,
432432
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
433-
env: ShardingEnv,
433+
env: Union[ShardingEnv, Dict[str, ShardingEnv]], # support for hybrid sharding
434434
fused_params: Optional[Dict[str, Any]] = None,
435435
device: Optional[torch.device] = None,
436436
feature_processor: Optional[FeatureProcessorsCollection] = None,
@@ -462,11 +462,22 @@ def __init__(
462462
f"Feature processor has inconsistent devices. Expected {feature_processor_device}, got {param.device}"
463463
)
464464

465+
total_world_size = 0
466+
if isinstance(env, Dict):
467+
for _, device_group in self._sharding_type_device_group_to_sharding.keys():
468+
devices = (
469+
device_group if isinstance(device_group, tuple) else (device_group,)
470+
)
471+
for d in devices:
472+
total_world_size += env[d].world_size
473+
else:
474+
total_world_size = env.world_size
475+
465476
if feature_processor_device is None:
466-
for _ in range(env.world_size):
477+
for _ in range(total_world_size):
467478
self.feature_processors_per_rank.append(feature_processor)
468479
else:
469-
for i in range(env.world_size):
480+
for i in range(total_world_size):
470481
# Generic copy, for example initailized on cpu but -> sharding as meta
471482
self.feature_processors_per_rank.append(
472483
copy.deepcopy(feature_processor)

torchrec/distributed/quant_state.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,16 @@ def get_bucket_offsets_per_virtual_table(
415415
}
416416

417417

418+
def get_param_id_from_type(is_sqebc: bool, is_sqmcec: bool, is_sfpebc: bool) -> str:
    """Return the parameter attribute path used in FQNs for a sharded quant module.

    Exactly one (or none) of the flags is expected to be true; the first
    matching flag wins, and the embedding-collection default is returned
    when no flag is set.

    Args:
        is_sqebc: module is a ShardedQuantEmbeddingBagCollection.
        is_sqmcec: module is a ShardedQuantManagedCollisionEmbeddingCollection.
        is_sfpebc: module is a ShardedQuantFeatureProcessedEmbeddingBagCollection.

    Returns:
        The dotted attribute path fragment (e.g. "embedding_bags") to use
        when building unsharded weight FQNs.
    """
    # Table-driven dispatch: first true flag determines the path fragment.
    flag_to_path = (
        (is_sqebc, "embedding_bags"),
        (is_sqmcec, "_embedding_module.embeddings"),
        (is_sfpebc, "_embedding_bag_collection.embedding_bags"),
    )
    for flag, path in flag_to_path:
        if flag:
            return path
    return "embeddings"
426+
427+
418428
def sharded_tbes_weights_spec(
419429
sharded_model: torch.nn.Module,
420430
virtual_table_name_to_bucket_lengths: Optional[Dict[str, list[int]]] = None,
@@ -450,11 +460,14 @@ def sharded_tbes_weights_spec(
450460
is_sqebc: bool = "ShardedQuantEmbeddingBagCollection" in type_name
451461
is_sqec: bool = "ShardedQuantEmbeddingCollection" in type_name
452462
is_sqmcec: bool = "ShardedQuantManagedCollisionEmbeddingCollection" in type_name
463+
is_sfpebc: bool = (
464+
"ShardedQuantFeatureProcessedEmbeddingBagCollection" in type_name
465+
)
453466

454-
if is_sqebc or is_sqec or is_sqmcec:
467+
if is_sqebc or is_sqec or is_sqmcec or is_sfpebc:
455468
assert (
456-
is_sqec + is_sqebc + is_sqmcec == 1
457-
), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection and ShardedQuantManagedCollisionEmbeddingCollection are true"
469+
is_sqec + is_sqebc + is_sqmcec + is_sfpebc == 1
470+
), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection and ShardedQuantFeatureProcessedEmbeddingBagCollection are true"
458471
tbes_configs: Dict[
459472
IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig
460473
] = module.tbes_configs()
@@ -546,8 +559,7 @@ def sharded_tbes_weights_spec(
546559
row_offsets,
547560
table_metadata.shard_offsets[1],
548561
]
549-
s: str = "embedding_bags" if is_sqebc else "embeddings"
550-
s = ("_embedding_module." if is_sqmcec else "") + s
562+
s: str = get_param_id_from_type(is_sqebc, is_sqmcec, is_sfpebc)
551563
unsharded_fqn_weight_prefix: str = f"{module_fqn}.{s}.{table_name}"
552564
unsharded_fqn_weight: str = unsharded_fqn_weight_prefix + ".weight"
553565

0 commit comments

Comments (0)