@@ -48,6 +48,8 @@ class JaggedToPaddedDenseOp
             const std::vector<Tensor>& offsets,
             at::ArrayRef<at::SymInt> max_lengths,
             const double padding_value)>();
+
+    at::AutoDispatchBelowAutograd mode;
     Tensor padded_values = op.call(values, offsets, max_lengths, padding_value);
 
     return {padded_values};
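The new `at::AutoDispatchBelowAutograd` guard is what keeps this wrapper from recursing: the wrapper itself is registered under the Autograd dispatch key (see the `TORCH_LIBRARY_IMPL(fbgemm, Autograd, m)` hunk below), so a plain `op.call` would resolve right back into it. A minimal sketch of the guard pattern, with a hypothetical `myns::my_op` standing in for the FBGEMM ops:

#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor my_op_autograd(const at::Tensor& input) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("myns::my_op", "")
                       .typed<at::Tensor(const at::Tensor&)>();
  // RAII guard: excludes the Autograd dispatch keys for the rest of this
  // scope, so op.call() resolves to the backend (CPU/CUDA) kernel instead
  // of re-entering this Autograd-registered wrapper.
  at::AutoDispatchBelowAutograd guard;
  return op.call(input);
}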
@@ -286,6 +288,7 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
         const Tensor& dense,
         const std::vector<Tensor>& offsets,
         std::optional<at::SymInt> total_L)>();
+    at::AutoDispatchBelowAutograd mode;
     auto output = op.call(dense, offsets, total_L);
 
     return {output};
@@ -785,14 +788,30 @@ class JaggedSliceOp : public torch::autograd::Function<JaggedSliceOp> {
 } // namespace
 
 ///@ingroup jagged-tensor-ops-cpu
-Tensor jagged_to_padded_dense(
+Tensor jagged_to_padded_dense_forward_autograd(
     const Tensor& values,
     const std::vector<Tensor>& offsets,
     const c10::SymIntArrayRef max_lengths,
     const double padding_value) {
   return JaggedToPaddedDenseOp::apply(
       values, offsets, max_lengths, padding_value)[0];
 }
+Tensor jagged_to_padded_dense(
+    const Tensor& values,
+    const std::vector<Tensor>& offsets,
+    const c10::SymIntArrayRef max_lengths,
+    const double padding_value) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "")
+          .typed<at::Tensor(
+              const Tensor& values,
+              const std::vector<Tensor>& offsets,
+              at::ArrayRef<at::SymInt> max_lengths,
+              const double padding_value)>();
+  Tensor output = op.call(values, offsets, max_lengths, padding_value);
+  return output;
+}
 
 ///@ingroup jagged-tensor-ops-cpu
 /// Output = x + y where x is jagged, y and output are dense
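This split is the core of the change: the public `jagged_to_padded_dense` no longer applies the `torch::autograd::Function` itself but redispatches to `fbgemm::jagged_to_padded_dense_forward`, while the autograd wrapper moves behind that op's Autograd key (registered in the last hunk below). A hypothetical minimal pairing showing the same pattern, with `myns::pad`/`myns::pad_forward` standing in for the real ops and the backend kernel assumed to be defined elsewhere:

#include <torch/library.h>

at::Tensor pad_cpu(const at::Tensor& v);              // backend kernel, defined elsewhere
at::Tensor pad_forward_autograd(const at::Tensor& v); // wraps the autograd::Function

// Public op: a thin redirect, so the dispatcher decides whether the call
// goes through the Autograd wrapper or straight to the backend kernel.
at::Tensor pad(const at::Tensor& v) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("myns::pad_forward", "")
                       .typed<at::Tensor(const at::Tensor&)>();
  return op.call(v);
}

TORCH_LIBRARY(myns, m) {
  m.def("pad(Tensor v) -> Tensor");
  m.def("pad_forward(Tensor v) -> Tensor");
}
TORCH_LIBRARY_IMPL(myns, CPU, m) {
  m.impl("pad_forward", TORCH_FN(pad_cpu));
}
TORCH_LIBRARY_IMPL(myns, Autograd, m) {
  m.impl("pad_forward", TORCH_FN(pad_forward_autograd));
}
TORCH_LIBRARY_IMPL(myns, CompositeImplicitAutograd, m) {
  m.impl("pad", TORCH_FN(pad));
}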
@@ -855,7 +874,20 @@ std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged(
     const Tensor& dense,
     const std::vector<Tensor>& offsets,
     std::optional<at::SymInt> total_L) {
-  return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets};
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "")
+                       .typed<Tensor(
+                           const Tensor& dense,
+                           const std::vector<Tensor>& offsets,
+                           std::optional<at::SymInt> total_L)>();
+  auto output = op.call(dense, offsets, total_L);
+  return {output, offsets};
+}
+Tensor dense_to_jagged_forward_autograd(
+    const Tensor& dense,
+    const std::vector<Tensor>& offsets,
+    std::optional<at::SymInt> total_L) {
+  return DenseToJaggedOp::apply(dense, offsets, total_L)[0];
 }
 
 ///@ingroup jagged-tensor-ops-cpu
@@ -973,6 +1005,12 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
   m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm));
   m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm));
   m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice));
+  m.impl(
+      "jagged_to_padded_dense_forward",
+      TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_forward_autograd));
+  m.impl(
+      "dense_to_jagged_forward",
+      TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_autograd));
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CompositeImplicitAutograd, m) {
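With these registrations in place, calling the forward op directly through the dispatcher still builds an autograd graph, since the Autograd key now routes to the `*_forward_autograd` wrappers before falling through to the backend kernels. A hedged usage sketch (shapes, offsets, and the max length are made up for illustration):

#include <torch/torch.h>

void example() {
  // Jagged tensor with 2 rows of lengths 3 and 5 (offsets {0, 3, 8}).
  auto values = torch::randn({8, 16}, torch::requires_grad());
  std::vector<at::Tensor> offsets = {torch::tensor({0, 3, 8})};

  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "")
          .typed<at::Tensor(
              const at::Tensor&,
              const std::vector<at::Tensor>&,
              at::ArrayRef<at::SymInt>,
              const double)>();
  auto padded = op.call(values, offsets, {at::SymInt(5)}, /*padding_value=*/0.0);
  padded.sum().backward(); // grad flows through JaggedToPaddedDenseOp::backward
}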