Commit f907e34

qxy11 authored and facebook-github-bot committed
Move sparse_op registration + correct sigmoid XL lowering settings (#4179)
Summary:
Pull Request resolved: #4179

X-link: facebookresearch/FBGEMM#1257

This part was causing issues in unit tests due to double fake operator registration. Moved this up into setup() to register these fake ops only once.

Reviewed By: kqfu

Differential Revision: D75007509

fbshipit-source-id: 3518920e4267a7e34fee3e17134d5e781cfab975
1 parent d111fcb commit f907e34
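
For context on the fix described in the summary, here is a minimal sketch (not the fbgemm code itself) of the pattern this change moves to: fake ("abstract"/meta) kernels are registered exactly once inside a guarded setup function, so re-running the registration path in unit tests cannot attempt a second registration of the same operator. The operator name mylib::my_op and the helper _fake_my_op are hypothetical; sparse_ops.py registers the real fbgemm ops through its local impl_abstract helper inside _setup().

import torch
from torch import Tensor


def _fake_my_op(x: Tensor) -> Tensor:
    # Shape-only ("fake"/abstract) kernel: produces an empty tensor with the
    # same metadata as the input, without touching real data.
    return torch.empty_like(x)


def _setup() -> None:
    # Idempotence guard: the attribute persists across repeated _setup() calls.
    _setup.done = getattr(_setup, "done", False)
    if _setup.done:
        return
    # Define the hypothetical op and attach its fake kernel exactly once.
    torch.library.define("mylib::my_op", "(Tensor x) -> Tensor")
    torch.library.impl_abstract("mylib::my_op", _fake_my_op)
    _setup.done = True


_setup()
_setup()  # second call is a no-op rather than a double registration

The removed @torch.library.register_fake registrations at the bottom of this diff performed the same registrations at module scope; consolidating them into the guarded _setup() is how the commit ensures they run only once.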

File tree: 1 file changed (+29 −26 lines)


fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 29 additions & 26 deletions
@@ -1145,6 +1145,30 @@ def permute_multi_embedding_function_impl_abstract(
     return output
 
 
+def lengths_range_abstract(
+    lengths: Tensor,
+    output_shape: Optional[Sequence[int]] = None,
+) -> Tensor:
+    torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
+    output_size = 0
+    if output_shape is not None:
+        output_size = math.prod(output_shape)
+    else:
+        ctx = torch.library.get_ctx()
+        output_size = ctx.new_dynamic_size()
+    return lengths.new_empty([output_size], dtype=lengths.dtype)
+
+
+def all_to_one_device(
+    input_tensors: List[Tensor],
+    target_device: torch.device,
+) -> List[Tensor]:
+    return [
+        torch.empty_like(input_tensor, device=torch.device("meta"))
+        for input_tensor in input_tensors
+    ]
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
@@ -1215,6 +1239,7 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
     )
     impl_abstract("fbgemm::segment_sum_csr", segment_sum_csr_abstract)
     impl_abstract("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward)
+    impl_abstract("fbgemm::all_to_one_device", all_to_one_device)
     impl_abstract(
         "fbgemm::batch_index_select_dim0", batch_index_select_dim0_abstract
     )
@@ -1282,6 +1307,10 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
         "fbgemm::generic_histogram_binning_calibration_by_feature",
         generic_histogram_binning_calibration_by_feature,
     )
+    impl_abstract(
+        "fbgemm::lengths_range",
+        lengths_range_abstract,
+    )
     impl_abstract(
         "fbgemm::permute_multi_embedding_function",
         permute_multi_embedding_function_impl_abstract,
@@ -1330,29 +1359,3 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
 
 
 _setup()
-
-
-@torch.library.register_fake("fbgemm::lengths_range")
-def lengths_range_abstract(
-    lengths: Tensor,
-    output_shape: Optional[Sequence[int]] = None,
-) -> Tensor:
-    torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
-    output_size = 0
-    if output_shape is not None:
-        output_size = math.prod(output_shape)
-    else:
-        ctx = torch.library.get_ctx()
-        output_size = ctx.new_dynamic_size()
-    return lengths.new_empty([output_size], dtype=lengths.dtype)
-
-
-@torch.library.register_fake("fbgemm::all_to_one_device")
-def all_to_one_device(
-    input_tensors: List[Tensor],
-    target_device: torch.device,
-) -> List[Tensor]:
-    return [
-        torch.empty_like(input_tensor, device=torch.device("meta"))
-        for input_tensor in input_tensors
-    ]
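
A side note on the lengths_range fake kernel above: its output length depends on the values in lengths, which a fake (meta) kernel cannot read, so it asks the library context for a fresh dynamic size via torch.library.get_ctx().new_dynamic_size(). The following is a self-contained sketch of the same idea against a hypothetical operator, mylib::range_from_lengths; the name, schema, and eager implementation are illustrative only and not part of this commit.

import torch
from torch import Tensor


@torch.library.custom_op("mylib::range_from_lengths", mutates_args=())
def range_from_lengths(lengths: Tensor) -> Tensor:
    # Eager implementation: concatenate arange(n) for every n in lengths,
    # e.g. [2, 3] -> [0, 1, 0, 1, 2].
    return torch.cat(
        [
            torch.arange(int(n), dtype=lengths.dtype, device=lengths.device)
            for n in lengths
        ]
    )


@torch.library.register_fake("mylib::range_from_lengths")
def _(lengths: Tensor) -> Tensor:
    torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
    # The output length depends on data the fake kernel cannot see, so it
    # allocates a new dynamic (unbacked) size for the output dimension.
    ctx = torch.library.get_ctx()
    output_size = ctx.new_dynamic_size()
    return lengths.new_empty([output_size], dtype=lengths.dtype)

Under fake-tensor tracing (for example inside torch.compile), the registered fake kernel returns a tensor whose length is an unbacked symbolic size rather than a concrete number, which mirrors the else branch of lengths_range_abstract; when output_shape is provided, that kernel can instead compute a concrete size with math.prod.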
