diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py
index 01b904781..ab767a317 100644
--- a/torchrec/distributed/quant_embeddingbag.py
+++ b/torchrec/distributed/quant_embeddingbag.py
@@ -430,7 +430,7 @@ def __init__(
         self,
         module: EmbeddingBagCollectionInterface,
         table_name_to_parameter_sharding: Dict[str, ParameterSharding],
-        env: ShardingEnv,
+        env: Union[ShardingEnv, Dict[str, ShardingEnv]],  # support for hybrid sharding
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
         feature_processor: Optional[FeatureProcessorsCollection] = None,
@@ -462,11 +462,30 @@ def __init__(
                     f"Feature processor has inconsistent devices. Expected {feature_processor_device}, got {param.device}"
                 )
 
+        if isinstance(env, Dict):
+            expected_device_type = "cuda"
+            for (
+                embedding_configs
+            ) in self._sharding_type_device_group_to_sharding_infos.values():
+                # throws if not all shards only have the expected device type
+                shard_device_type = get_device_from_sharding_infos(embedding_configs)
+                if isinstance(shard_device_type, tuple):
+                    raise RuntimeError(
+                        f"Sharding across multiple device types for FeatureProcessedEmbeddingBagCollection is not supported yet, got {shard_device_type}",
+                    )
+                assert (
+                    shard_device_type == expected_device_type
+                ), f"Expected {expected_device_type} but got {shard_device_type} for FeatureProcessedEmbeddingBagCollection sharding device type"
+
+            # TODO(hcxu): support hybrid sharding with feature_processors_per_rank: ModuleList(ModuleList()), if compatible
+            world_size = env[expected_device_type].world_size
+        else:
+            world_size = env.world_size
+
         if feature_processor_device is None:
-            for _ in range(env.world_size):
-                self.feature_processors_per_rank.append(feature_processor)
+            self.feature_processors_per_rank += [feature_processor] * world_size
         else:
-            for i in range(env.world_size):
+            for i in range(world_size):
                 # Generic copy, for example initailized on cpu but -> sharding as meta
                 self.feature_processors_per_rank.append(
                     copy.deepcopy(feature_processor)
diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py
index 833ff494c..333f04a71 100644
--- a/torchrec/distributed/quant_state.py
+++ b/torchrec/distributed/quant_state.py
@@ -419,6 +419,16 @@ def get_bucket_offsets_per_virtual_table(
     }
 
 
+def get_param_id_from_type(is_sqebc: bool, is_sqmcec: bool, is_sfpebc: bool) -> str:
+    if is_sqebc:
+        return "embedding_bags"
+    elif is_sqmcec:
+        return "_embedding_module.embeddings"
+    elif is_sfpebc:
+        return "_embedding_bag_collection.embedding_bags"
+    return "embeddings"
+
+
 def sharded_tbes_weights_spec(
     sharded_model: torch.nn.Module,
     virtual_table_name_to_bucket_lengths: Optional[Dict[str, list[int]]] = None,
@@ -454,11 +464,14 @@ def sharded_tbes_weights_spec(
         is_sqebc: bool = "ShardedQuantEmbeddingBagCollection" in type_name
         is_sqec: bool = "ShardedQuantEmbeddingCollection" in type_name
         is_sqmcec: bool = "ShardedQuantManagedCollisionEmbeddingCollection" in type_name
+        is_sfpebc: bool = (
+            "ShardedQuantFeatureProcessedEmbeddingBagCollection" in type_name
+        )
 
-        if is_sqebc or is_sqec or is_sqmcec:
+        if is_sqebc or is_sqec or is_sqmcec or is_sfpebc:
             assert (
-                is_sqec + is_sqebc + is_sqmcec == 1
-            ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection and ShardedQuantManagedCollisionEmbeddingCollection are true"
+                is_sqec + is_sqebc + is_sqmcec + is_sfpebc == 1
+            ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection and ShardedQuantFeatureProcessedEmbeddingBagCollection are true"
             tbes_configs: Dict[
                 IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig
             ] = module.tbes_configs()
@@ -550,8 +563,7 @@ def sharded_tbes_weights_spec(
                         row_offsets,
                         table_metadata.shard_offsets[1],
                     ]
-                    s: str = "embedding_bags" if is_sqebc else "embeddings"
-                    s = ("_embedding_module." if is_sqmcec else "") + s
+                    s: str = get_param_id_from_type(is_sqebc, is_sqmcec, is_sfpebc)
                     unsharded_fqn_weight_prefix: str = f"{module_fqn}.{s}.{table_name}"
                     unsharded_fqn_weight: str = unsharded_fqn_weight_prefix + ".weight"
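
Note (not part of the patch): a minimal standalone sketch of the mapping implemented by the new get_param_id_from_type helper in torchrec/distributed/quant_state.py. The function body is copied verbatim from the hunk above; the assertions are illustrative only and assume the boolean flags are computed the same way sharded_tbes_weights_spec derives them from the sharded module's type name.

# Sketch: expected unsharded parameter-path prefixes per sharded module type.
def get_param_id_from_type(is_sqebc: bool, is_sqmcec: bool, is_sfpebc: bool) -> str:
    if is_sqebc:
        return "embedding_bags"
    elif is_sqmcec:
        return "_embedding_module.embeddings"
    elif is_sfpebc:
        return "_embedding_bag_collection.embedding_bags"
    return "embeddings"

# ShardedQuantEmbeddingBagCollection
assert get_param_id_from_type(True, False, False) == "embedding_bags"
# ShardedQuantManagedCollisionEmbeddingCollection
assert get_param_id_from_type(False, True, False) == "_embedding_module.embeddings"
# ShardedQuantFeatureProcessedEmbeddingBagCollection (newly supported in this diff)
assert get_param_id_from_type(False, False, True) == "_embedding_bag_collection.embedding_bags"
# ShardedQuantEmbeddingCollection (no flag set falls through)
assert get_param_id_from_type(False, False, False) == "embeddings"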