Commit d291621

*_pool*d: improve error handling and error messages. (#9648)
This PR refactors the `{avg,max}_pool{2,3}d` operation implementations by improving their error messages and returning a status type value.

**Key Changes:**

- Make `tensor_methods::avg_pool_nd{,_backward}` return `StatusOr<XLATensorPtr>`
- Make `tensor_methods::max_pool_nd{,_backward}` return `StatusOr<std::tuple<XLATensorPtr, XLATensorPtr>>`
- Improve error messages and error handling
- Remove `CheckIntList`
- Create the following new functions (a rough sketch follows below):
  - `RepeatIfSingleElement(span, n)`: if `span` is a single-element list, create a new one repeating that element `n` times; otherwise, return the elements in `span`.
  - `CheckPoolNdInputHasSize(...)`: check that the given list has a specific size.
  - `FillAndCheckPoolNdInputs(...)`: run the two functions above on the `*_pool*d` common inputs, i.e. `kernel_size`, `stride`, and `padding`.
1 parent 834d29e commit d291621
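The new helper functions themselves are not part of this excerpt. Below is a rough, hypothetical sketch of how they might fit together, assuming `absl::Span` inputs and `absl::Status`-based error reporting; the names come from the commit message, but the signatures and the composition are guesses, and only the error text is anchored to the expectations added in `test/test_ops_error_message.py` below.

```cpp
// Hypothetical sketch only: names taken from the commit message, signatures
// and composition assumed. Not the actual torch_xla implementation.
#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"

// If `span` holds a single element, repeat it `n` times; otherwise copy it.
std::vector<int64_t> RepeatIfSingleElement(absl::Span<const int64_t> span,
                                           size_t n) {
  return span.size() == 1 ? std::vector<int64_t>(n, span[0])
                          : std::vector<int64_t>(span.begin(), span.end());
}

// Produce an error in the style the new tests expect, e.g.
// "avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3."
absl::Status CheckPoolNdInputHasSize(const std::string& op,
                                     const std::string& name,
                                     absl::Span<const int64_t> arg, size_t n) {
  if (arg.size() != n) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "%s(): expected argument %s [%s] (size: %d) to have size of %d.", op,
        name, absl::StrJoin(arg, ", "), arg.size(), n));
  }
  return absl::OkStatus();
}

// Hypothetical composition over the three common pool arguments. An empty
// stride falls back to kernel_size, matching PyTorch's pooling semantics.
absl::Status FillAndCheckPoolNdInputs(const std::string& op, size_t n,
                                      std::vector<int64_t>& kernel_size,
                                      std::vector<int64_t>& stride,
                                      std::vector<int64_t>& padding) {
  kernel_size = RepeatIfSingleElement(kernel_size, n);
  stride = stride.empty() ? kernel_size : RepeatIfSingleElement(stride, n);
  padding = RepeatIfSingleElement(padding, n);
  absl::Status status = CheckPoolNdInputHasSize(op, "kernel_size", kernel_size, n);
  if (!status.ok()) return status;
  status = CheckPoolNdInputHasSize(op, "stride", stride, n);
  if (!status.ok()) return status;
  return CheckPoolNdInputHasSize(op, "padding", padding, n);
}
```

With helpers along these lines, `avg_pool_nd` / `max_pool_nd` can validate `kernel_size`, `stride`, and `padding` up front and propagate a `Status` error to the caller, which the `XLA_ASSIGN_OR_THROW` call sites in the diffs below surface as the `RuntimeError` messages checked in the new Python test.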

File tree

6 files changed (+302, -167 lines)


test/cpp/test_tensor.cpp

Lines changed: 45 additions & 34 deletions
@@ -4,6 +4,7 @@
 #include <limits>
 #include <vector>
 
+#include "absl/base/nullability.h"
 #include "test/cpp/cpp_test_util.h"
 #include "test/cpp/torch_xla_test.h"
 #include "torch/csrc/autograd/variable.h"
@@ -297,14 +298,18 @@ TEST_F(TensorTest, TestMaxPool2D) {
           /*padding=*/{padding, padding}, /*dilation=*/{1, 1},
           /*ceil_mode=*/false);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+        XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                             XLATensor::Create(input, device));
-        auto dev_output = tensor_methods::max_pool_nd(
-            dev_input,
-            /*spatial_dim_count=*/2,
-            /*kernel_size=*/{kernel_size, kernel_size},
-            /*stride=*/{stride, stride},
-            /*padding=*/{padding, padding}, /*ceil_mode=*/false);
+        std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr>
+            dev_output;
+        XLA_ASSIGN_OR_THROW(
+            dev_output,
+            tensor_methods::max_pool_nd(
+                dev_input,
+                /*spatial_dim_count=*/2,
+                /*kernel_size=*/{kernel_size, kernel_size},
+                /*stride=*/{stride, stride},
+                /*padding=*/{padding, padding}, /*ceil_mode=*/false));
         AllClose(output, std::get<0>(dev_output));
       });
     }
@@ -322,15 +327,18 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) {
           /*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1},
           /*ceil_mode=*/false);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+        XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                             XLATensor::Create(input, device));
-        auto dev_output = tensor_methods::max_pool_nd(
-            dev_input,
-            /*spatial_dim_count=*/2,
-            /*kernel_size=*/{kernel_size, kernel_size + 1},
-            /*stride=*/{stride, stride + 1},
-            /*padding=*/{padding, padding + 1},
-            /*ceil_mode=*/false);
+        std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr>
+            dev_output;
+        XLA_ASSIGN_OR_THROW(dev_output,
+                            tensor_methods::max_pool_nd(
+                                dev_input,
+                                /*spatial_dim_count=*/2,
+                                /*kernel_size=*/{kernel_size, kernel_size + 1},
+                                /*stride=*/{stride, stride + 1},
+                                /*padding=*/{padding, padding + 1},
+                                /*ceil_mode=*/false));
         AllClose(output, std::get<0>(dev_output));
      });
     }
@@ -351,16 +359,17 @@ TEST_F(TensorTest, TestAvgPool2D) {
           /*ceil_mode=*/false, count_include_pad,
           /*divisor_override=*/std::nullopt);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+        XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                             XLATensor::Create(input, device));
-        XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
-            dev_input,
-            /*spatial_dim_count=*/2,
-            /*kernel_size=*/{kernel_size, kernel_size},
-            /*stride=*/{stride, stride},
-            /*padding=*/{padding, padding},
-            /*ceil_mode=*/false, count_include_pad,
-            /*divisor_override=*/std::nullopt);
+        XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_output,
+                            tensor_methods::avg_pool_nd(
+                                dev_input,
+                                /*spatial_dim_count=*/2,
+                                /*kernel_size=*/{kernel_size, kernel_size},
+                                /*stride=*/{stride, stride},
+                                /*padding=*/{padding, padding},
+                                /*ceil_mode=*/false, count_include_pad,
+                                /*divisor_override=*/std::nullopt));
         AllClose(output, dev_output);
       });
     }
@@ -382,17 +391,19 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
           /*count_include_pad=*/count_include_pad,
           /*divisor_override=*/std::nullopt);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+        XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                             XLATensor::Create(input, device));
-        XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
-            dev_input,
-            /*spatial_dim_count=*/2,
-            /*kernel_size=*/{kernel_size, kernel_size + 1},
-            /*stride=*/{stride, stride + 1},
-            /*padding=*/{padding, padding + 1},
-            /*ceil_mode=*/false,
-            /*count_include_pad=*/count_include_pad,
-            /*divisor_override=*/std::nullopt);
+        XLA_ASSIGN_OR_THROW(
+            absl_nonnull XLATensorPtr dev_output,
+            tensor_methods::avg_pool_nd(
+                dev_input,
+                /*spatial_dim_count=*/2,
+                /*kernel_size=*/{kernel_size, kernel_size + 1},
+                /*stride=*/{stride, stride + 1},
+                /*padding=*/{padding, padding + 1},
+                /*ceil_mode=*/false,
+                /*count_include_pad=*/count_include_pad,
+                /*divisor_override=*/std::nullopt));
         AllClose(output, dev_output);
       });
     }

test/test_ops_error_message.py

Lines changed: 26 additions & 0 deletions
@@ -331,6 +331,32 @@ def test():
         callable=test,
         expect="""uniform_(): expected `from` (5) <= `to` (2).""")
 
+  def test_avg_pool_3d_raises_error_on_bad_spec(self):
+    device = torch_xla.device()
+    a = torch.rand(1, 1, 4, 4, 4, device=device)
+
+    def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
+      return lambda: torch.nn.functional.avg_pool3d(a, kernel_size, stride,
+                                                     padding)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=gen_test_fn(kernel_size=[2, 2]),
+        expect="""avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3."""
+    )
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=gen_test_fn(stride=[1, 2]),
+        expect="""avg_pool3d(): expected argument stride [1, 2] (size: 2) to have size of 3."""
+    )
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=gen_test_fn(padding=[1, 2]),
+        expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
+    )
+
 
 if __name__ == "__main__":
   unittest.main()

torch_xla/csrc/aten_autograd_ops.cpp

Lines changed: 30 additions & 22 deletions
@@ -192,11 +192,13 @@ torch::Tensor MaxPool3dAutogradFunction::forward(
     return std::get<0>(results);
   }
   ctx->save_for_backward({self});
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return bridge::AtenFromXlaTensor(std::get<0>(outputs));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/3,
+                                  kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::get<0>(output));
 }
 
 torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
@@ -220,13 +222,15 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
                                                padding, dilation,
                                                ceil_mode, indices);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output_0,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output_0,
                       bridge::GetXlaTensor(grad_output[0]));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
+                          kernel_size, stride, padding, ceil_mode));
+  grad = bridge::AtenFromXlaTensor(std::move(output));
 
   torch::Tensor undef;
   torch::autograd::variable_list grad_inputs = {grad, undef, undef,
@@ -239,24 +243,28 @@ torch::Tensor max_pool2d_forward(torch::Tensor self,
                                  torch::IntArrayRef stride,
                                  torch::IntArrayRef padding,
                                  torch::IntArrayRef dilation, bool ceil_mode) {
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return bridge::AtenFromXlaTensor(std::get<0>(outputs));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/2,
+                                  kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::get<0>(output));
 }
 
 torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
                                   torch::IntArrayRef kernel_size,
                                   torch::IntArrayRef stride,
                                   torch::IntArrayRef padding, bool ceil_mode) {
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/2,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output, xla_self, /*spatial_dim_count=*/2,
+                          kernel_size, stride, padding, ceil_mode));
+  auto grad = bridge::AtenFromXlaTensor(std::move(output));
   return grad;
 }
