
Commit f46faaa

bdhirsh authored and facebook-github-bot committed

remove pt2 compliant xfails for jagged ops (#5068)

Summary:
X-link: facebookresearch/FBGEMM#2075

Letting CI tell me what tests to run to fix these ops for pt2.

Differential Revision: D85630006

1 parent ed0c4cb · commit f46faaa
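For context: the xfails removed below belong to FBGEMM's generated pt2-compliance tests, which run each operator through torch.library.opcheck (schema, fake/meta kernel, autograd registration, and AOT-dispatch checks). Below is a minimal sketch of that kind of check, with illustrative inputs and assuming fbgemm_gpu is importable; it is not part of this commit, and the real tests build their inputs with Hypothesis.

# Sketch only, not from this commit.
import torch
import fbgemm_gpu  # noqa: F401  -- loads the torch.ops.fbgemm.* operators

values = torch.randn(6, 4, requires_grad=True)  # total_L = 6, D = 4
offsets = [torch.tensor([0, 2, 6])]             # one jagged dim, batch size 2
torch.library.opcheck(
    torch.ops.fbgemm.jagged_to_padded_dense,
    (values, offsets, [4], 0.0),                # max_lengths = [4], padding_value = 0.0
)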

File tree

8 files changed: +117 −22 lines changed


fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu

Lines changed: 0 additions & 5 deletions
@@ -8,11 +8,6 @@
 
 #include "common.cuh"
 
-FBGEMM_OP_DISPATCH(CUDA, "dense_to_jagged", fbgemm_gpu::dense_to_jagged);
-FBGEMM_OP_DISPATCH(
-    CUDA,
-    "jagged_to_padded_dense",
-    fbgemm_gpu::jagged_to_padded_dense);
 FBGEMM_OP_DISPATCH(
     CUDA,
     "jagged_dense_elementwise_add",

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp

Lines changed: 40 additions & 2 deletions
@@ -48,6 +48,8 @@ class JaggedToPaddedDenseOp
             const std::vector<Tensor>& offsets,
             at::ArrayRef<at::SymInt> max_lengths,
             const double padding_value)>();
+
+    at::AutoDispatchBelowAutograd mode;
     Tensor padded_values = op.call(values, offsets, max_lengths, padding_value);
 
     return {padded_values};
@@ -286,6 +288,7 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
             const Tensor& dense,
             const std::vector<Tensor>& offsets,
             std::optional<at::SymInt> total_L)>();
+    at::AutoDispatchBelowAutograd mode;
     auto output = op.call(dense, offsets, total_L);
 
     return {output};
@@ -785,14 +788,30 @@ class JaggedSliceOp : public torch::autograd::Function<JaggedSliceOp> {
 } // namespace
 
 ///@ingroup jagged-tensor-ops-cpu
-Tensor jagged_to_padded_dense(
+Tensor jagged_to_padded_dense_forward_autograd(
     const Tensor& values,
     const std::vector<Tensor>& offsets,
     const c10::SymIntArrayRef max_lengths,
     const double padding_value) {
   return JaggedToPaddedDenseOp::apply(
       values, offsets, max_lengths, padding_value)[0];
 }
+Tensor jagged_to_padded_dense(
+    const Tensor& values,
+    const std::vector<Tensor>& offsets,
+    const c10::SymIntArrayRef max_lengths,
+    const double padding_value) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "")
+          .typed<at::Tensor(
+              const Tensor& values,
+              const std::vector<Tensor>& offsets,
+              at::ArrayRef<at::SymInt> max_lengths,
+              const double padding_value)>();
+  Tensor output = op.call(values, offsets, max_lengths, padding_value);
+  return output;
+}
 
 ///@ingroup jagged-tensor-ops-cpu
 /// Output = x + y where x is jagged, y and output are dense
@@ -855,7 +874,20 @@ std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged(
     const Tensor& dense,
     const std::vector<Tensor>& offsets,
     std::optional<at::SymInt> total_L) {
-  return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets};
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "")
+                       .typed<Tensor(
+                           const Tensor& dense,
+                           const std::vector<Tensor>& offsets,
+                           std::optional<at::SymInt> total_L)>();
+  auto output = op.call(dense, offsets, total_L);
+  return {output, offsets};
+}
+Tensor dense_to_jagged_forward_autograd(
+    const Tensor& dense,
+    const std::vector<Tensor>& offsets,
+    std::optional<at::SymInt> total_L) {
+  return DenseToJaggedOp::apply(dense, offsets, total_L)[0];
 }
 
 ///@ingroup jagged-tensor-ops-cpu
@@ -973,6 +1005,12 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
   m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm));
   m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm));
   m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice));
+  m.impl(
+      "jagged_to_padded_dense_forward",
+      TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_forward_autograd));
+  m.impl(
+      "dense_to_jagged_forward",
+      TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_autograd));
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CompositeImplicitAutograd, m) {
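The shape of this change: the torch::autograd::Function wrappers now back the *_forward ops (registered under the Autograd key at the end of the hunk), with at::AutoDispatchBelowAutograd making the nested op.call() reach the backend kernel instead of re-entering the Autograd wrapper, while the public jagged_to_padded_dense / dense_to_jagged become thin wrappers that simply redispatch to the *_forward ops. The jagged_to_padded_dense export test further down checks that this lets the composite op decompose into the forward op. Below is a rough Python analog of the same structure, using a toy op and assuming PyTorch 2.4+ for torch.library.custom_op / register_autograd; it is not FBGEMM code.

# Toy analog, not from this commit.
import torch
from torch.library import custom_op, register_autograd

@custom_op("toylib::double_forward", mutates_args=())
def double_forward(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@double_forward.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Fake/meta kernel: shape and dtype only, no data.
    return torch.empty_like(x)

def _backward(ctx, grad):
    # d(2x)/dx = 2
    return grad * 2

register_autograd("toylib::double_forward", _backward)

def double(x: torch.Tensor) -> torch.Tensor:
    # Public wrapper: simply redispatches to the forward op, mirroring the
    # C++ jagged_to_padded_dense / dense_to_jagged wrappers above.
    return torch.ops.toylib.double_forward(x)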

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp

Lines changed: 1 addition & 3 deletions
@@ -1818,13 +1818,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense);
   DISPATCH_TO_CPU("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense);
-  DISPATCH_TO_CPU("dense_to_jagged", fbgemm_gpu::dense_to_jagged);
   DISPATCH_TO_CPU(
       "dense_to_jagged_forward", fbgemm_gpu::dense_to_jagged_forward);
-  DISPATCH_TO_CPU("jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
   DISPATCH_TO_CPU(
       "jagged_to_padded_dense_forward",
-      fbgemm_gpu::jagged_to_padded_dense_forward);
+      fbgemm_gpu::jagged_to_padded_dense_forward_cpu);
   DISPATCH_TO_CPU(
       "jagged_to_padded_dense_backward",
       fbgemm_gpu::jagged_to_padded_dense_backward);

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp

Lines changed: 6 additions & 3 deletions
@@ -53,18 +53,21 @@ Tensor jagged_to_padded_dense_meta(
 
 Tensor jagged_to_padded_dense_backward_meta(
     const at::Tensor& grad_output,
-    const std::vector<Tensor>& /*offsets*/,
+    const std::vector<Tensor>& offsets,
     at::SymInt total_L) {
   const auto& grad_padded_values = grad_output;
 
-  at::SymInt D = grad_padded_values.sym_size(-1);
+  const bool D_folded = grad_padded_values.dim() == offsets.size() + 1;
+  const auto& grad_padded_values_view =
+      D_folded ? grad_padded_values.unsqueeze(-1) : grad_padded_values;
+  at::SymInt D = grad_padded_values_view.sym_size(-1);
   // Initialize with zeros so output will be zero for the portion truncated
   // in forward.
   auto grad_values =
       at::zeros_symint({std::move(total_L), D}, grad_padded_values.options());
 
   TORCH_CHECK(grad_values.is_meta());
-  return grad_values;
+  return D_folded ? grad_values.squeeze(-1) : grad_values;
 }
 
 Tensor jagged_dense_dense_elementwise_add_jagged_output_forward_meta(
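The meta (shape-inference) kernel for jagged_to_padded_dense_backward now tolerates gradients whose trailing embedding dimension is folded away: when the padded gradient's rank is len(offsets) + 1 (i.e. the values were 1-D), it is unsqueezed to a trailing dim of 1 for the shape math and squeezed back on return. A small standalone illustration of that shape rule (toy tensors, not FBGEMM code):

# Toy illustration, not from this commit.
import torch

offsets = [torch.tensor([0, 2, 5])]  # one jagged dim
grad_padded = torch.randn(2, 4)      # [B, max_L]; no trailing D, so D is "folded"
total_L = 5

D_folded = grad_padded.dim() == len(offsets) + 1
view = grad_padded.unsqueeze(-1) if D_folded else grad_padded
grad_values = torch.zeros(total_L, view.size(-1))
grad_values = grad_values.squeeze(-1) if D_folded else grad_values
print(grad_values.shape)  # torch.Size([5]), matching the 1-D values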

fbgemm_gpu/test/jagged/common.py

Lines changed: 1 addition & 9 deletions
@@ -43,15 +43,7 @@
 # Please avoid putting tests here, you should put operator-specific
 # skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json
 # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
-additional_decorators: dict[str, list[Callable]] = {
-    "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [
-        # This operator has been grandfathered in. We need to fix this test failure.
-        unittest.expectedFailure,
-    ],
-    "test_pt2_compliant_tag_fbgemm_jagged_to_padded_dense": [
-        unittest.expectedFailure,
-    ],
-}
+additional_decorators: dict[str, list[Callable]] = {}
 
 
 def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor:

fbgemm_gpu/test/jagged/dense_to_jagged_test.py

Lines changed: 5 additions & 0 deletions
@@ -80,6 +80,11 @@ def _test_dense_to_jagged(
         jagged_values.backward(ref_output_values)
         torch.testing.assert_close(dense.grad, ref_values)
 
+        torch.library.opcheck(
+            torch.ops.fbgemm.dense_to_jagged,
+            (dense.detach().requires_grad_(True), offsets),
+        )
+
     @given(
         num_jagged_dim=st.integers(1, 5),
         outer_dense_size=st.integers(0, 5),

fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py

Lines changed: 20 additions & 0 deletions
@@ -158,6 +158,26 @@ def test_jagged_index_select_2d(
             rtol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
             atol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
         )
+        if known_shape:
+            with torch.no_grad():
+                tmp_output, _ = torch.ops.fbgemm.jagged_index_select(
+                    values, lengths, indices
+                )
+                num_dense_output_rows = tmp_output.shape[0]
+            torch.library.opcheck(
+                torch.ops.fbgemm.jagged_index_select.default,
+                (
+                    values.detach().requires_grad_(),
+                    lengths,
+                    indices,
+                    num_dense_output_rows,
+                ),
+            )
+        else:
+            torch.library.opcheck(
+                torch.ops.fbgemm.jagged_index_select.default,
+                (values.detach().requires_grad_(), lengths, indices),
+            )
 
     @given(
         max_seq_length=st.integers(5, 10),

fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py

Lines changed: 44 additions & 0 deletions
@@ -113,6 +113,50 @@ def test_jagged_to_padded_dense(
             rtol=1e-3,
         )
 
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, b, c, d):
+                return torch.ops.fbgemm.jagged_to_padded_dense(a, b, c, d)
+
+        with torch.inference_mode():
+            gm = torch.export.export(
+                Mod(),
+                (
+                    x_values.float().requires_grad_(True),
+                    x_offsets,
+                    max_lengths.astype(int).tolist(),
+                    padding_value,
+                ),
+            ).run_decompositions()
+            num_fw_ops = len(
+                [
+                    x
+                    for x in gm.graph.nodes
+                    if x.target is torch.ops.fbgemm.jagged_to_padded_dense_forward.default
+                ]
+            )
+            num_composite_ops = len(
+                [
+                    x
+                    for x in gm.graph.nodes
+                    if x.target is torch.ops.fbgemm.jagged_to_padded_dense.default
+                ]
+            )
+            self.assertEqual(num_fw_ops, 1)
+            self.assertEqual(num_composite_ops, 0)
+
+        torch.library.opcheck(
+            torch.ops.fbgemm.jagged_to_padded_dense,
+            (
+                x_values.float().requires_grad_(True),
+                x_offsets,
+                max_lengths,
+                padding_value,
+            ),
+        )
+
     @given(
         num_jagged_dim=st.integers(1, 5),
         outer_dense_size=st.integers(0, 5),
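The export check above asserts that, after run_decompositions(), the graph calls fbgemm.jagged_to_padded_dense_forward exactly once and the composite fbgemm.jagged_to_padded_dense not at all. When such an assertion fails, one quick way to see what actually remains is to dump the call_function targets; in the sketch below, gm is assumed to be the decomposed ExportedProgram from a torch.export.export(...).run_decompositions() call like the one above (debugging aid only, not part of this commit).

# Assumes `gm` from an export + run_decompositions() call as above.
for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.target)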
