Skip to content

Commit d8bd415

Browse files
aporialiaofacebook-github-bot
authored andcommitted
Pass in Compute kernel type in New plan generation (#3061)
Summary: Pull Request resolved: #3061 This is tricky - a small edge case that can't be reproduced locally - OSS environment will expose where compute kernel may be detected as "fused" when not passed in. This will cause the unit test to occasionally change the kernel type in resharding (due to the new plan passed in). Passing in the kernel type in the new plan generation fixes this. Reviewed By: aliafzal Differential Revision: D76233463 fbshipit-source-id: 740bfb6619df2ee862b3bf84f23dec395e14bb83
1 parent 861b00b commit d8bd415

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,9 @@ def dynamic_sharding_test(
452452

453453
new_per_param_sharding = {}
454454

455+
assert len(sharders) == 1
456+
# pyre-ignore
457+
kernel_type = sharders[0]._kernel_type
455458
# Construct parameter shardings
456459
for i in range(num_tables):
457460
table_name = tables[i].name
@@ -466,7 +469,7 @@ def dynamic_sharding_test(
466469
)
467470
# TODO: CW sharding constructor takes in different args
468471
new_per_param_sharding[table_name] = sharding_type_constructor(
469-
rank=new_ranks[i][0]
472+
rank=new_ranks[i][0], compute_kernel=kernel_type
470473
)
471474

472475
new_module_sharding_plan = construct_module_sharding_plan(

0 commit comments

Comments
 (0)