@@ -207,8 +207,14 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
+        config3 = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
         ebc = EmbeddingBagCollection(
-            tables=[config1, config2],
+            tables=[config1, config2, config3],
             is_weighted=False,
         )

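The hunk above grows the test fixture from two embedding tables to three, so the variable-batch (VBE) test can exercise features f1, f2 and f3. A minimal sketch of the resulting collection, assuming torchrec's public `EmbeddingBagConfig`/`EmbeddingBagCollection` API and the same dims as the diff (the loop form is my condensation, not the test's code):

```python
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# One table per feature; feature_names routes input key "fN" to table "tN".
tables = [
    EmbeddingBagConfig(
        name=f"t{i}", embedding_dim=5, num_embeddings=10, feature_names=[f"f{i}"]
    )
    for i in (1, 2, 3)
]
ebc = EmbeddingBagCollection(tables=tables, is_weighted=False)
```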
@@ -293,16 +299,29 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

-    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
-        id_list_features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
-            lengths=torch.tensor([3, 3, 2]),
-            stride_per_key_per_rank=[[2], [1]],
-            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        kjt_1 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
+            ),
+        )
+        kjt_2 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=torch.tensor([[1], [2], [3]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
+            ),
         )

-        eager_out = model(id_list_features)
+        eager_out = model(kjt_1)
+        eager_out_2 = model(kjt_2)

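For reference, in a variable-batch-embedding (VBE) `KeyedJaggedTensor` each key may carry a different batch size: `stride_per_key_per_rank` lists those per-key batch sizes, and `inverse_indices` maps the deduplicated rows back out to a common batch. A small self-check of how `kjt_1`'s metadata hangs together, using only the numbers from the diff (my reading of the layout, not part of the test):

```python
import torch

lengths = torch.tensor([1, 2, 3, 2, 1, 1])      # 6 bags across 3 keys
strides = [3, 2, 1]                             # per-key batch sizes, from [[3], [2], [1]]
per_key_lengths = torch.split(lengths, strides) # f1 -> [1, 2, 3], f2 -> [2, 1], f3 -> [1]
assert [len(t) for t in per_key_lengths] == strides
assert int(lengths.sum()) == 10                 # matches the 10 entries in `values`
```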
@@ -309,26 +328,31 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         # Serialize EBC
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
         ep = torch.export.export(
             model,
-            (id_list_features,),
+            (kjt_1,),
             {},
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )

         # Run forward on ExportedProgram
-        ep_output = ep.module()(id_list_features)
+        ep_output = ep.module()(kjt_1)
+        ep_output_2 = ep.module()(kjt_2)

+        self.assertEqual(len(ep_output), len(kjt_1.keys()))
+        self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape, tensor.shape)
+            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
+        for i, tensor in enumerate(ep_output_2):
+            self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])

         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

         # check EBC config
-        for i in range(5):
+        for i in range(1):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
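Note that the ExportedProgram assertions only pin down `shape[1]`: with a VBE input, the batch dimension of the unflattened program's output is not guaranteed to match eager mode, while the embedding dimension must. Under my reading of `inverse_indices` (each key is expanded to the largest per-key batch before pooling), the eager outputs for `kjt_1` should all be `(3, 5)`:

```python
import torch

# Metadata copied from kjt_1 above; embedding_dim=5 comes from the configs.
stride_per_key_per_rank = torch.tensor([[3], [2], [1]])
inverse_indices = torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]])
embedding_dim = 5

max_batch = int(stride_per_key_per_rank.max())      # 3
assert inverse_indices.shape == (3, max_batch)      # one expansion row per key
expected_shapes = [(max_batch, embedding_dim)] * 3  # (3, 5) for f1, f2, f3
```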
@@ -343,36 +367,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
             self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
             self.assertEqual(deserialized.feature_names, orginal.feature_names)

-        # check FPEBC config
-        for i in range(2):
-            fpebc_name = f"fpebc{i + 1}"
-            assert isinstance(
-                getattr(deserialized_model, fpebc_name),
-                FeatureProcessedEmbeddingBagCollection,
-            )
-
-            for deserialized, orginal in zip(
-                getattr(
-                    deserialized_model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-                getattr(
-                    model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-            ):
-                self.assertEqual(deserialized.name, orginal.name)
-                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
-                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
-                self.assertEqual(deserialized.feature_names, orginal.feature_names)
-
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
-        deserialized_out = deserialized_model(id_list_features)
+        deserialized_out = deserialized_model(kjt_1)

         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

+        deserialized_out_2 = deserialized_model(kjt_2)
+
+        self.assertEqual(len(deserialized_out_2), len(eager_out_2))
+        for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
+            self.assertEqual(deserialized.shape, orginal.shape)
+            self.assertTrue(torch.allclose(deserialized, orginal))
+
     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
         feature1 = KeyedJaggedTensor.from_offsets_sync(
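The full round trip exercised by the test is encapsulate → export → unflatten → decapsulate. A hedged sketch of that pipeline as a free function (`roundtrip` is a hypothetical helper; the `torchrec.ir` import paths are assumed from the names used in the diff):

```python
import torch
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules

def roundtrip(model: torch.nn.Module, example_kjt):
    # Swap sparse modules for export-friendly IR stand-ins, recording their FQNs.
    model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
    ep = torch.export.export(
        model,
        (example_kjt,),
        {},
        strict=False,
        # Keeps the KJT pytree intact so the unflattened EP can run forward.
        preserve_module_call_signature=tuple(sparse_fqns),
    )
    # Rebuild the real sparse modules inside the unflattened ExportedProgram.
    return decapsulate_ir_modules(torch.export.unflatten(ep), JsonSerializer)
```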