
Commit fb8932a

iamzainhuda authored and meta-codesync[bot] committed
Add row based sharding support for FeatureProcessedEBC (meta-pytorch#3406)
Summary:
Pull Request resolved: meta-pytorch#3406
X-link: meta-pytorch#3281

In this diff we introduce row-based sharding (TWRW, RW, GRID) support for feature processors. Previously, feature processors did not support row-based sharding because feature processors are data parallel: splitting the input for row-based shards meant the wrong feature processor weights were accessed. In column-wise and data-parallel sharding approaches, the data is duplicated, ensuring the correct weight is accessed across ranks.

The indices/buckets were previously calculated after the input split/distribution; to make this compatible with row-based sharding, we now calculate them before the input split/distribution. This couples the train pipeline and the feature processors. For each feature, we preprocess the input and place the calculated indices in KJT.weights; this propagates the indices correctly and indexes into the right weight for the final step of feature processing. This applies in both pipelined and non-pipelined situations: the input modification is done either in the pipelined forward call or in the input dist of the FPEBC, as determined by the pipelining flag set through _rewrite_model in the train pipeline.

Reviewed By: Fiery

Differential Revision: D82248545

fbshipit-source-id: 4a8ad3bc6d14c684986ec0524ad787e964d8855a
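To make the mechanism above concrete, here is a minimal conceptual sketch, not the code in this diff, of what "calculate indices before the input split and place them in KJT.weights" can look like. The helper names and the position-bucket logic are assumptions for illustration only; the actual preprocessing in this change lives in modify_input_for_feature_processor (see torchrec/distributed/fp_embeddingbag.py below).

# Conceptual sketch only, NOT the implementation in this diff: compute per-value
# position "buckets" for each feature BEFORE the row-wise input split and stash
# them in KJT.weights so every rank later indexes the correct (data-parallel
# replicated) feature-processor weight.
from typing import Dict, List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def _positions_from_lengths(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # For lengths [2, 3] -> positions [0, 1, 0, 1, 2], clamped to max_len - 1.
    chunks: List[torch.Tensor] = [torch.arange(int(n)) for n in lengths.tolist()]
    positions = torch.cat(chunks) if chunks else torch.empty(0, dtype=torch.long)
    return positions.clamp(max=max_len - 1)


def stash_fp_indices_in_weights(
    kjt: KeyedJaggedTensor, max_feature_lengths: Dict[str, int]
) -> KeyedJaggedTensor:
    # KJT values are laid out feature-major, so concatenating per-key positions
    # in key order keeps the weights aligned with values across any later split.
    per_key = [
        _positions_from_lengths(kjt[key].lengths(), max_feature_lengths[key]).float()
        for key in kjt.keys()
    ]
    return KeyedJaggedTensor(
        keys=kjt.keys(),
        values=kjt.values(),
        lengths=kjt.lengths(),
        weights=torch.cat(per_key),
    )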
1 parent 4f1f62d commit fb8932a

File tree: 8 files changed, +330 −25 lines

torchrec/distributed/fp_embeddingbag.py

Lines changed: 50 additions & 9 deletions
@@ -8,7 +8,18 @@
 # pyre-strict

 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)

 import torch
 from torch import nn
@@ -31,14 +42,20 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import append_prefix, init_parameters
+from torchrec.distributed.utils import (
+    append_prefix,
+    init_parameters,
+    modify_input_for_feature_processor,
+)
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
     FeatureProcessedEmbeddingBagCollection,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

+_T = TypeVar("_T")
+

 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -74,6 +91,16 @@ def __init__(
             )
         )

+        self._row_wise_sharded: bool = False
+        for param_sharding in table_name_to_parameter_sharding.values():
+            if param_sharding.sharding_type in [
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.GRID_SHARD.value,
+            ]:
+                self._row_wise_sharded = True
+                break
+
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups

         self._is_collection: bool = False
@@ -96,6 +123,11 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
+        if not self.is_pipelined and self._row_wise_sharded:
+            # transform input to support row based sharding when not pipelined
+            modify_input_for_feature_processor(
+                features, self._feature_processors, self._is_collection
+            )
         return self._embedding_bag_collection.input_dist(ctx, features)

     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -105,10 +137,7 @@ def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
                 kjt_list.append(self._feature_processors(features))
             else:
                 kjt_list.append(
-                    apply_feature_processors_to_kjt(
-                        features,
-                        self._feature_processors,
-                    )
+                    apply_feature_processors_to_kjt(features, self._feature_processors)
                 )
         return KJTList(kjt_list)

@@ -117,7 +146,6 @@ def compute(
         ctx: EmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
-
         fp_features = self.apply_feature_processors_to_kjt_list(dist_input)
         return self._embedding_bag_collection.compute(ctx, fp_features)

@@ -166,6 +194,18 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
     def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         self._embedding_bag_collection._initialize_torch_state(skip_registering)

+    def preprocess_input(
+        self, args: List[_T], kwargs: Mapping[str, _T]
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        for x in args + list(kwargs.values()):
+            if isinstance(x, KeyedJaggedTensor):
+                modify_input_for_feature_processor(
+                    features=x,
+                    feature_processors=self._feature_processors,
+                    is_collection=self._is_collection,
+                )
+        return args, kwargs
+

 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -191,7 +231,6 @@ def shard(
         device: Optional[torch.device] = None,
         module_fqn: Optional[str] = None,
     ) -> ShardedFeatureProcessedEmbeddingBagCollection:
-
         if device is None:
             device = torch.device("cuda")

@@ -228,12 +267,14 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
         if compute_device_type in {"mtia"}:
            return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]

-        # No row wise because position weighted FP and RW don't play well together.
         types = [
            ShardingType.DATA_PARALLEL.value,
            ShardingType.TABLE_WISE.value,
            ShardingType.COLUMN_WISE.value,
            ShardingType.TABLE_COLUMN_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.GRID_SHARD.value,
        ]

        return types

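With the sharder above now advertising TWRW, RW, and GRID, a FeatureProcessedEmbeddingBagCollection can be planned row-wise. Below is a hedged sketch following the pattern used in the new test further down; the table/feature names and sizes are arbitrary, and exact constructor signatures may vary across torchrec versions.

# Sketch only: request row-wise sharding for an FPEBC, mirroring the new test below.
from torchrec.distributed.fp_embeddingbag import (
    FeatureProcessedEmbeddingBagCollectionSharder,
)
from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection

tables = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=8,
        name="table_0",
        feature_names=["feature_0"],
    )
]
fp_ebc = FeatureProcessedEmbeddingBagCollection(
    EmbeddingBagCollection(tables=tables, is_weighted=True),
    PositionWeightedModuleCollection(max_feature_lengths={"feature_0": 10}),
)
sharder = FeatureProcessedEmbeddingBagCollectionSharder()
# Row-wise placement is now a valid choice for FPEBC tables.
module_sharding_plan = construct_module_sharding_plan(
    fp_ebc,
    per_param_sharding={"table_0": row_wise()},
    world_size=2,
    device_type="cuda",
    sharder=sharder,
)
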
torchrec/distributed/tests/test_fp_embeddingbag.py

Lines changed: 0 additions & 1 deletion
@@ -231,7 +231,6 @@ class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
     def test_sharding_ebc(
         self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool
     ) -> None:
-
         import hypothesis

         # don't need to test entire matrix
torchrec/distributed/tests/test_fp_embeddingbag_utils.py

Lines changed: 6 additions & 1 deletion
@@ -86,7 +86,12 @@ def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
         pred = torch.cat(
             [
                 fp_ebc_out[key]
-                for key in ["feature_0", "feature_1", "feature_2", "feature_3"]
+                for key in [
+                    "feature_0",
+                    "feature_1",
+                    "feature_2",
+                    "feature_3",
+                ]
             ],
             dim=1,
         )

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 164 additions & 1 deletion
@@ -22,7 +22,10 @@
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._dynamo.utils import counters
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embedding_types import (
+    EmbeddingComputeKernel,
+    EmbeddingTableConfig,
+)
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.fp_embeddingbag import (
     FeatureProcessedEmbeddingBagCollectionSharder,
@@ -31,8 +34,13 @@
 from torchrec.distributed.model_parallel import DMPCollection
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
+    row_wise,
     table_wise,
 )
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestNegSamplingModule,
@@ -333,6 +341,161 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         torch.testing.assert_close(pred_gpu.cpu(), pred)


+def fp_ebc_rw_sharding_test_runner(
+    rank: int,
+    world_size: int,
+    tables: List[EmbeddingTableConfig],
+    weighted_tables: List[EmbeddingTableConfig],
+    data: List[Tuple[ModelInput, List[ModelInput]]],
+    backend: str = "nccl",
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        assert ctx.pg is not None
+        sharder = cast(
+            ModuleSharder[nn.Module],
+            FeatureProcessedEmbeddingBagCollectionSharder(),
+        )
+
+        class DummyWrapper(nn.Module):
+            def __init__(self, sparse_arch):
+                super().__init__()
+                self.m = sparse_arch
+
+            def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
+                return self.m(model_input.idlist_features)
+
+        max_feature_lengths = [10, 10, 12, 12]
+        sparse_arch = DummyWrapper(
+            create_module_and_freeze(
+                tables=tables,  # pyre-ignore[6]
+                device=ctx.device,
+                use_fp_collection=False,
+                max_feature_lengths=max_feature_lengths,
+            )
+        )
+
+        # compute_kernel = EmbeddingComputeKernel.FUSED.value
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch.m._fp_ebc,
+            per_param_sharding={
+                "table_0": row_wise(),
+                "table_1": row_wise(),
+                "table_2": row_wise(),
+                "table_3": row_wise(),
+            },
+            world_size=2,
+            device_type=ctx.device.type,
+            sharder=sharder,
+        )
+        sharded_sparse_arch_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+        sharded_sparse_arch_no_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        batches = []
+        for d in data:
+            batches.append(d[1][ctx.rank].to(ctx.device))
+        dataloader = iter(batches)
+
+        optimizer_no_pipeline = optim.SGD(
+            sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
+        )
+        optimizer_pipeline = optim.SGD(
+            sharded_sparse_arch_pipeline.parameters(), lr=0.1
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            sharded_sparse_arch_pipeline,
+            optimizer_pipeline,
+            ctx.device,
+        )
+
+        for batch in batches[:-2]:
+            batch = batch.to(ctx.device)
+            optimizer_no_pipeline.zero_grad()
+            loss, pred = sharded_sparse_arch_no_pipeline(batch)
+            loss.backward()
+            optimizer_no_pipeline.step()
+
+            pred_pipeline = pipeline.progress(dataloader)
+            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
+
+
+class TrainPipelineGPUTest(MultiProcessTestBase):
+    def setUp(self, backend: str = "nccl") -> None:
+        super().setUp()
+
+        self.pipeline_class = TrainPipelineSparseDist
+        num_features = 4
+        num_weighted_features = 4
+        self.tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]
+        self.weighted_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="weighted_table_" + str(i),
+                feature_names=["weighted_feature_" + str(i)],
+            )
+            for i in range(num_weighted_features)
+        ]
+
+        self.backend = backend
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        if self.backend == "nccl" and self.device == torch.device("cpu"):
+            self.skipTest("NCCL not supported on CPUs.")
+
+    def _generate_data(
+        self,
+        num_batches: int = 5,
+        batch_size: int = 1,
+        max_feature_lengths: Optional[List[int]] = None,
+    ) -> List[Tuple[ModelInput, List[ModelInput]]]:
+        return [
+            ModelInput.generate(
+                tables=self.tables,
+                weighted_tables=self.weighted_tables,
+                batch_size=batch_size,
+                world_size=2,
+                num_float_features=10,
+                max_feature_lengths=max_feature_lengths,
+            )
+            for i in range(num_batches)
+        ]
+
+    def test_fp_ebc_rw(self) -> None:
+        data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
+        self._run_multi_process_test(
+            callable=fp_ebc_rw_sharding_test_runner,
+            world_size=2,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            data=data,
+        )
+
+
 class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(

torchrec/distributed/train_pipeline/utils.py

Lines changed: 3 additions & 0 deletions
@@ -147,6 +147,7 @@ def _start_data_dist(
         # and this info was done in the _rewrite_model by tracing the
         # entire model to get the arg_info_list
         args, kwargs = forward.args.build_args_kwargs(batch)
+        args, kwargs = module.preprocess_input(args, kwargs)

         # Start input distribution.
         module_ctx = module.create_context()
@@ -382,6 +383,8 @@ def _rewrite_model(  # noqa C901
             logger.info(f"Module '{node.target}' will be pipelined")
             child = sharded_modules[node.target]
             original_forwards.append(child.forward)
+            # Set pipelining flag on the child module
+            child.is_pipelined = True
             # pyre-ignore[8] Incompatible attribute type
             child.forward = pipelined_forward(
                 node.target,

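For context on the new call site above, preprocess_input acts as a hook on the pipelined sharded module. Below is a hedged sketch of the hook's expected shape, assuming a pass-through default for modules that need no input rewriting; the mixin name is hypothetical, and FPEBC's actual override appears in fp_embeddingbag.py above.

# Hypothetical mixin: the pass-through shape of the hook _start_data_dist now calls.
from typing import List, Mapping, Tuple, TypeVar

_T = TypeVar("_T")


class NoOpPreprocessInputMixin:
    # Set to True by _rewrite_model when the module's forward is pipelined.
    is_pipelined: bool = False

    def preprocess_input(
        self, args: List[_T], kwargs: Mapping[str, _T]
    ) -> Tuple[List[_T], Mapping[str, _T]]:
        # Default: leave args/kwargs untouched; the sharded FPEBC overrides this
        # to stamp feature-processor indices into the KJT before input dist.
        return args, kwargs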