@@ -1145,6 +1145,30 @@ def permute_multi_embedding_function_impl_abstract(
     return output


+def lengths_range_abstract(
+    lengths: Tensor,
+    output_shape: Optional[Sequence[int]] = None,
+) -> Tensor:
+    torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
+    output_size = 0
+    if output_shape is not None:
+        output_size = math.prod(output_shape)
+    else:
+        ctx = torch.library.get_ctx()
+        output_size = ctx.new_dynamic_size()
+    return lengths.new_empty([output_size], dtype=lengths.dtype)
+
+
+def all_to_one_device(
+    input_tensors: List[Tensor],
+    target_device: torch.device,
+) -> List[Tensor]:
+    return [
+        torch.empty_like(input_tensor, device=torch.device("meta"))
+        for input_tensor in input_tensors
+    ]
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
@@ -1215,6 +1239,7 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
     )
     impl_abstract("fbgemm::segment_sum_csr", segment_sum_csr_abstract)
     impl_abstract("fbgemm::dense_to_jagged_forward", dense_to_jagged_forward)
+    impl_abstract("fbgemm::all_to_one_device", all_to_one_device)
     impl_abstract(
         "fbgemm::batch_index_select_dim0", batch_index_select_dim0_abstract
     )
@@ -1282,6 +1307,10 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
         "fbgemm::generic_histogram_binning_calibration_by_feature",
         generic_histogram_binning_calibration_by_feature,
     )
+    impl_abstract(
+        "fbgemm::lengths_range",
+        lengths_range_abstract,
+    )
     impl_abstract(
         "fbgemm::permute_multi_embedding_function",
         permute_multi_embedding_function_impl_abstract,
@@ -1330,29 +1359,3 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None


 _setup()
-
-
-@torch.library.register_fake("fbgemm::lengths_range")
-def lengths_range_abstract(
-    lengths: Tensor,
-    output_shape: Optional[Sequence[int]] = None,
-) -> Tensor:
-    torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
-    output_size = 0
-    if output_shape is not None:
-        output_size = math.prod(output_shape)
-    else:
-        ctx = torch.library.get_ctx()
-        output_size = ctx.new_dynamic_size()
-    return lengths.new_empty([output_size], dtype=lengths.dtype)
-
-
-@torch.library.register_fake("fbgemm::all_to_one_device")
-def all_to_one_device(
-    input_tensors: List[Tensor],
-    target_device: torch.device,
-) -> List[Tensor]:
-    return [
-        torch.empty_like(input_tensor, device=torch.device("meta"))
-        for input_tensor in input_tensors
-    ]
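
For reference, the relocated lengths_range_abstract follows PyTorch's standard fake-kernel pattern: when output_shape is supplied, the fake output size is simply math.prod(output_shape); otherwise the true size is data-dependent, so ctx.new_dynamic_size() returns an unbacked symbolic size that lets tracing proceed without running the real kernel. Below is a minimal, self-contained sketch of that pattern. The "demo" namespace, the op schema, and the helper names are hypothetical stand-ins so the sketch runs without fbgemm_gpu installed; the real registrations above target the "fbgemm" namespace.

    import math
    from typing import Optional, Sequence

    import torch
    from torch import Tensor
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # Hypothetical "demo" library standing in for the real "fbgemm" namespace.
    _demo_lib = torch.library.Library("demo", "DEF")
    _demo_lib.define(
        "lengths_range(Tensor lengths, int[]? output_shape=None) -> Tensor"
    )


    @torch.library.register_fake("demo::lengths_range")
    def _demo_lengths_range_fake(
        lengths: Tensor,
        output_shape: Optional[Sequence[int]] = None,
    ) -> Tensor:
        torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor")
        if output_shape is not None:
            # Static case: the output size is fully determined by the arguments.
            output_size = math.prod(output_shape)
        else:
            # Dynamic case: the size depends on tensor data, so ask the
            # library context for an unbacked symbolic size instead.
            ctx = torch.library.get_ctx()
            output_size = ctx.new_dynamic_size()
        return lengths.new_empty([output_size], dtype=lengths.dtype)


    # Exercise the fake kernel under FakeTensorMode: only shapes and dtypes
    # are propagated, and no real kernel is (or needs to be) registered.
    # The ShapeEnv is only required if the dynamic-size branch is taken.
    with FakeTensorMode(shape_env=ShapeEnv()):
        lengths = torch.empty(4, dtype=torch.int64)
        out = torch.ops.demo.lengths_range(lengths, [2, 3])
        assert out.shape == (6,)

The all_to_one_device fake follows the same idea for its multi-tensor case: returning empty_like copies on the meta device preserves each input's shape and dtype without allocating anything on target_device, which is all that compile-time tracing needs.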