
Commit a90725a

bdhirsh authored and facebook-github-bot committed
remove pt2 compliant xfails for jagged ops
Summary: letting CI tell me what tests to run to fix these ops for pt2
Differential Revision: D85630006
1 parent ed0c4cb commit a90725a

6 files changed: +77 additions, −19 deletions

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu

Lines changed: 0 additions & 4 deletions

@@ -9,10 +9,6 @@
 #include "common.cuh"

 FBGEMM_OP_DISPATCH(CUDA, "dense_to_jagged", fbgemm_gpu::dense_to_jagged);
-FBGEMM_OP_DISPATCH(
-    CUDA,
-    "jagged_to_padded_dense",
-    fbgemm_gpu::jagged_to_padded_dense);
 FBGEMM_OP_DISPATCH(
     CUDA,
     "jagged_dense_elementwise_add",

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp

Lines changed: 25 additions & 1 deletion

@@ -785,14 +785,30 @@ class JaggedSliceOp : public torch::autograd::Function<JaggedSliceOp> {
 } // namespace

 ///@ingroup jagged-tensor-ops-cpu
-Tensor jagged_to_padded_dense(
+Tensor jagged_to_padded_dense_forward_autograd(
     const Tensor& values,
     const std::vector<Tensor>& offsets,
     const c10::SymIntArrayRef max_lengths,
     const double padding_value) {
   return JaggedToPaddedDenseOp::apply(
       values, offsets, max_lengths, padding_value)[0];
 }
+Tensor jagged_to_padded_dense(
+    const Tensor& values,
+    const std::vector<Tensor>& offsets,
+    const c10::SymIntArrayRef max_lengths,
+    const double padding_value) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "")
+          .typed<at::Tensor(
+              const Tensor& values,
+              const std::vector<Tensor>& offsets,
+              at::ArrayRef<at::SymInt> max_lengths,
+              const double padding_value)>();
+  Tensor output = op.call(values, offsets, max_lengths, padding_value);
+  return output;
+}

 ///@ingroup jagged-tensor-ops-cpu
 /// Output = x + y where x is jagged, y and output are dense
@@ -973,8 +989,16 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
   m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm));
   m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm));
   m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice));
+  m.impl(
+      "jagged_to_padded_dense_forward",
+      TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_forward_autograd));
 }

+// These ops are all custom autograd::Functions, which we are registering
+// to the Autograd key above.
+// The only reason that we *also* register them to the
+// CompositeImplicitAutograd key is so that they will decompose by default
+// when using torch.export (even under inference_mode).
 TORCH_LIBRARY_IMPL(fbgemm, CompositeImplicitAutograd, m) {
   m.impl("jagged_index_select", TORCH_FN(fbgemm_gpu::jagged_index_select_2d));
   m.impl("dense_to_jagged", TORCH_FN(fbgemm_gpu::dense_to_jagged));

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp

Lines changed: 1 addition & 2 deletions

@@ -1821,10 +1821,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU("dense_to_jagged", fbgemm_gpu::dense_to_jagged);
   DISPATCH_TO_CPU(
       "dense_to_jagged_forward", fbgemm_gpu::dense_to_jagged_forward);
-  DISPATCH_TO_CPU("jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
   DISPATCH_TO_CPU(
       "jagged_to_padded_dense_forward",
-      fbgemm_gpu::jagged_to_padded_dense_forward);
+      fbgemm_gpu::jagged_to_padded_dense_forward_cpu);
   DISPATCH_TO_CPU(
       "jagged_to_padded_dense_backward",
       fbgemm_gpu::jagged_to_padded_dense_backward);

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp

Lines changed: 6 additions & 3 deletions

@@ -53,18 +53,21 @@ Tensor jagged_to_padded_dense_meta(

 Tensor jagged_to_padded_dense_backward_meta(
     const at::Tensor& grad_output,
-    const std::vector<Tensor>& /*offsets*/,
+    const std::vector<Tensor>& offsets,
     at::SymInt total_L) {
   const auto& grad_padded_values = grad_output;

-  at::SymInt D = grad_padded_values.sym_size(-1);
+  const bool D_folded = grad_padded_values.dim() == offsets.size() + 1;
+  const auto& grad_padded_values_view =
+      D_folded ? grad_padded_values.unsqueeze(-1) : grad_padded_values;
+  at::SymInt D = grad_padded_values_view.sym_size(-1);
   // Initialize with zeros so output will be zero for the portion truncated
   // in forward.
   auto grad_values =
       at::zeros_symint({std::move(total_L), D}, grad_padded_values.options());

   TORCH_CHECK(grad_values.is_meta());
-  return grad_values;
+  return D_folded ? grad_values.squeeze(-1) : grad_values;
 }

 Tensor jagged_dense_dense_elementwise_add_jagged_output_forward_meta(
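
The D_folded branch above handles inputs whose jagged values carry no trailing embedding dimension D: in that case the padded gradient has exactly len(offsets) + 1 dimensions, so the meta function temporarily unsqueezes a size-1 D and squeezes it back out of the result. A minimal Python sketch of the same shape logic (the function name is hypothetical):

import torch

def backward_output_shape(grad_output, offsets, total_L):
    # When values were 1-D (no trailing D dim), the padded gradient has
    # len(offsets) jagged dims plus the batch dim, i.e. D is "folded".
    D_folded = grad_output.dim() == len(offsets) + 1
    g = grad_output.unsqueeze(-1) if D_folded else grad_output
    D = g.size(-1)
    # Zeros, so entries truncated in the forward pass get zero gradient.
    grad_values = torch.zeros(total_L, D, dtype=grad_output.dtype)
    return grad_values.squeeze(-1) if D_folded else grad_values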

fbgemm_gpu/test/jagged/common.py

Lines changed: 1 addition & 9 deletions

@@ -43,15 +43,7 @@
 # Please avoid putting tests here, you should put operator-specific
 # skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json
 # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
-additional_decorators: dict[str, list[Callable]] = {
-    "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [
-        # This operator has been grandfathered in. We need to fix this test failure.
-        unittest.expectedFailure,
-    ],
-    "test_pt2_compliant_tag_fbgemm_jagged_to_padded_dense": [
-        unittest.expectedFailure,
-    ],
-}
+additional_decorators: dict[str, list[Callable]] = {}


 def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor:

fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py

Lines changed: 44 additions & 0 deletions

@@ -113,6 +113,50 @@ def test_jagged_to_padded_dense(
             rtol=1e-3,
         )

+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, b, c, d):
+                return torch.ops.fbgemm.jagged_to_padded_dense(a, b, c, d)
+
+        with torch.inference_mode():
+            gm = torch.export.export(
+                Mod(),
+                (
+                    x_values.float().requires_grad_(True),
+                    x_offsets,
+                    max_lengths.astype(int).tolist(),
+                    padding_value,
+                ),
+            ).run_decompositions()
+            num_fw_ops = len(
+                [
+                    x
+                    for x in gm.graph.nodes
+                    if x.target is torch.ops.fbgemm.jagged_to_padded_dense_forward.default
+                ]
+            )
+            num_composite_ops = len(
+                [
+                    x
+                    for x in gm.graph.nodes
+                    if x.target is torch.ops.fbgemm.jagged_to_padded_dense.default
+                ]
+            )
+            self.assertEqual(num_fw_ops, 1)
+            self.assertEqual(num_composite_ops, 0)
+
+        torch.library.opcheck(
+            torch.ops.fbgemm.jagged_to_padded_dense,
+            (
+                x_values.float().requires_grad_(True),
+                x_offsets,
+                max_lengths,
+                padding_value,
+            ),
+        )
+
     @given(
         num_jagged_dim=st.integers(1, 5),
         outer_dense_size=st.integers(0, 5),
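
For reference, torch.library.opcheck (available since PyTorch 2.2) runs a battery of PT2-compliance checks (schema validity, autograd registration, fake-tensor and AOT-dispatch behavior) against one sample invocation. A standalone sketch outside the test harness, assuming fbgemm_gpu is installed and using a single jagged dimension; the sample shapes are illustrative:

import torch
import fbgemm_gpu  # assumed installed; importing registers the fbgemm ops

values = torch.randn(6, 4, requires_grad=True)  # total_L=6, D=4
offsets = [torch.tensor([0, 2, 6])]             # one jagged dim, batch of 2
torch.library.opcheck(
    torch.ops.fbgemm.jagged_to_padded_dense,
    (values, offsets, [3], 0.0),  # max_lengths=[3], padding_value=0.0
)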
