
Commit a9fcb17

jd7-tr authored and facebook-github-bot committed
Add missing fields to KJT's PyTree flatten/unflatten logic for VBE KJT (#2952)
Summary:

# Context

* Currently the torchrec IR serializer does not support exporting variable batch (VBE) KJTs: the `stride_per_key_per_rank` and `inverse_indices` fields are needed to deserialize a VBE KJT, but they are not included in KJT's PyTree flatten/unflatten functions.
* This diff updates KJT's PyTree flatten/unflatten functions to include `stride_per_key_per_rank` and `inverse_indices`.

# Ref

Differential Revision: D74295924
1 parent bc78a4c commit a9fcb17

File tree

2 files changed: +72 −43 lines changed

* torchrec/ir/tests/test_serializer.py
* torchrec/sparse/jagged_tensor.py
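To make the change concrete, here is a minimal sketch (not part of the commit) of the PyTree round trip that `torch.export` performs on its example inputs. It assumes a torchrec build with this patch applied and reuses the VBE KJT values from the pre-existing test below; before the patch, `tree_flatten` dropped both VBE fields, so the rebuilt KJT carried no `inverse_indices`.

import torch
import torch.utils._pytree as pytree

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A variable-batch (VBE) KJT: stride_per_key_per_rank and inverse_indices
# carry the per-key batch information that values/lengths alone cannot.
vbe_kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=torch.tensor([[2], [1]]),
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

# With the patched flatten/unflatten, both VBE fields survive the round trip;
# without the patch, inverse_indices() on the rebuilt KJT would fail.
leaves, spec = pytree.tree_flatten(vbe_kjt)
rebuilt = pytree.tree_unflatten(leaves, spec)
assert rebuilt.keys() == vbe_kjt.keys()
assert torch.equal(rebuilt.inverse_indices()[1], vbe_kjt.inverse_indices()[1])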

torchrec/ir/tests/test_serializer.py

Lines changed: 45 additions & 35 deletions
@@ -207,8 +207,14 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
+        config3 = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
         ebc = EmbeddingBagCollection(
-            tables=[config1, config2],
+            tables=[config1, config2, config3],
             is_weighted=False,
         )

@@ -293,42 +299,60 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

-    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
-        id_list_features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
-            lengths=torch.tensor([3, 3, 2]),
-            stride_per_key_per_rank=[[2], [1]],
-            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        kjt_1 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
+            ),
+        )
+        kjt_2 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=torch.tensor([[1], [2], [3]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
+            ),
         )

-        eager_out = model(id_list_features)
+        eager_out = model(kjt_1)
+        eager_out_2 = model(kjt_2)

         # Serialize EBC
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
         ep = torch.export.export(
             model,
-            (id_list_features,),
+            (kjt_1,),
             {},
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )

         # Run forward on ExportedProgram
-        ep_output = ep.module()(id_list_features)
+        ep_output = ep.module()(kjt_1)
+        ep_output_2 = ep.module()(kjt_2)

+        self.assertEqual(len(ep_output), len(kjt_1.keys()))
+        self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape, tensor.shape)
+            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
+        for i, tensor in enumerate(ep_output_2):
+            self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])

         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

         # check EBC config
-        for i in range(5):
+        for i in range(1):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -343,36 +367,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
                 self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
                 self.assertEqual(deserialized.feature_names, orginal.feature_names)

-        # check FPEBC config
-        for i in range(2):
-            fpebc_name = f"fpebc{i + 1}"
-            assert isinstance(
-                getattr(deserialized_model, fpebc_name),
-                FeatureProcessedEmbeddingBagCollection,
-            )
-
-            for deserialized, orginal in zip(
-                getattr(
-                    deserialized_model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-                getattr(
-                    model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-            ):
-                self.assertEqual(deserialized.name, orginal.name)
-                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
-                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
-                self.assertEqual(deserialized.feature_names, orginal.feature_names)
-
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
-        deserialized_out = deserialized_model(id_list_features)
+        deserialized_out = deserialized_model(kjt_1)

         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

+        deserialized_out_2 = deserialized_model(kjt_2)
+
+        self.assertEqual(len(deserialized_out_2), len(eager_out_2))
+        for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
+            self.assertEqual(deserialized.shape, orginal.shape)
+            self.assertTrue(torch.allclose(deserialized, orginal))
+
     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
         feature1 = KeyedJaggedTensor.from_offsets_sync(

torchrec/sparse/jagged_tensor.py

Lines changed: 27 additions & 8 deletions
@@ -1728,6 +1728,8 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
         "_weights",
         "_lengths",
         "_offsets",
+        "_stride_per_key_per_rank",
+        "_inverse_indices",
     ]

     def __init__(
@@ -3021,13 +3023,19 @@ def dist_init(

 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], List[str]]]:
+    values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
+    values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
+
+    return values, (
+        t._keys,
+        t._inverse_indices[0] if t._inverse_indices is not None else [],
+    )


 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], Tuple[List[str], List[str]]]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3036,15 +3044,26 @@ def _kjt_flatten_with_keys(


 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: Tuple[
+        List[str], List[str]
+    ],  # context is the (_keys, _inverse_indices[0]) tuple
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    return KeyedJaggedTensor(
+        context[0],
+        *values[:-2],
+        stride_per_key_per_rank=values[-2],
+        inverse_indices=(context[1], values[-1]) if values[-1] is not None else None,
+    )


 def _kjt_flatten_spec(
     t: KeyedJaggedTensor, spec: TreeSpec
 ) -> List[Optional[torch.Tensor]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields]
+    values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
+    values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
+
+    return values


 register_pytree_node(
@@ -3059,7 +3078,7 @@ def _kjt_flatten_spec(


 def flatten_kjt_list(
     kjt_arr: List[KeyedJaggedTensor],
-) -> Tuple[List[Optional[torch.Tensor]], List[List[str]]]:
+) -> Tuple[List[Optional[torch.Tensor]], List[Tuple[List[str], List[str]]]]:
     _flattened_data = []
     _flattened_context = []
     for t in kjt_arr:
@@ -3070,7 +3089,7 @@ def flatten_kjt_list(


 def unflatten_kjt_list(
-    values: List[Optional[torch.Tensor]], contexts: List[List[str]]
+    values: List[Optional[torch.Tensor]], contexts: List[Tuple[List[str], List[str]]]
 ) -> List[KeyedJaggedTensor]:
     num_kjt_fields = len(KeyedJaggedTensor._fields)
     length = len(values)
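
For completeness, a hedged usage sketch of the two list helpers above, again assuming a torchrec build with this patch. Each KJT contributes len(KeyedJaggedTensor._fields) leaf slots (None where a field is unset) plus one (keys, inverse-indices keys) context tuple.

from typing import List

import torch

from torchrec.sparse.jagged_tensor import (
    KeyedJaggedTensor,
    flatten_kjt_list,
    unflatten_kjt_list,
)

kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=torch.tensor([[2], [1]]),
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

# Flatten two KJTs into one flat tensor list plus per-KJT contexts, then
# rebuild; the VBE fields are preserved because they are now in KJT._fields.
tensors, contexts = flatten_kjt_list([kjt, kjt])
rebuilt: List[KeyedJaggedTensor] = unflatten_kjt_list(tensors, contexts)
assert rebuilt[0].keys() == kjt.keys()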
