diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h
index 169ae010aa98..9090ae9707c8 100644
--- a/include/tvm/topi/reduction.h
+++ b/include/tvm/topi/reduction.h
@@ -333,21 +333,28 @@ inline Tensor sum(const Tensor& data, const Array<Integer>& axis, bool keepdims
 }
 
 inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
-  ICHECK_GE(data->shape.size(), target_shape.size());
-  auto ishape = detail::GetConstIntValues(data->shape, "ishape");
-  auto oshape = detail::GetConstIntValues(target_shape, "oshape");
+  const auto& ishape = data->shape;
+  const auto& oshape = target_shape;
+  int isize = data->shape.size();
+  int osize = target_shape.size();
+
+  ICHECK_GE(isize, osize)
+      << "Invalid collapse: input dimensionality smaller than output dimensionality.\ninput shape: "
+      << data->shape << "\nvs\noutput shape: " << target_shape;
 
   std::vector<int> reduce_axes;
   std::vector<int> squeeze_axes;
-  for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
-    if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
+  tvm::PrimExpr one(1);
+
+  for (int i_ax = isize - 1, o_ax = osize - 1; i_ax >= 0; --i_ax) {
+    if (o_ax >= 0 && topi::detail::EqualCheck(ishape[i_ax], oshape[o_ax])) {
       --o_ax;
       continue;
     }
     reduce_axes.push_back(i_ax);
     if (o_ax < 0) {  // squeeze o_ax if was added during expansion
       squeeze_axes.push_back(i_ax);
-    } else if (oshape[o_ax] == 1) {
+    } else if (topi::detail::EqualCheck(one, oshape[o_ax])) {
       --o_ax;
     }
   }
diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py
index 45d07af577a3..5045cb817457 100644
--- a/python/tvm/topi/reduction.py
+++ b/python/tvm/topi/reduction.py
@@ -248,3 +248,34 @@ def prod(data, axis=None, keepdims=False):
     ret : tvm.te.Tensor
     """
     return cpp.prod(data, axis, keepdims)
+
+
+def collapse_sum(data, target_shape):
+    """Return a summation of data to the given shape.
+
+    collapse_sum is intended as the backward operator of topi broadcast operators in the automatic
+    differentiation process.
+
+    We expect that data is the result of broadcasting some tensor of target_shape in some
+    broadcast operation. Thus target_shape and data.shape must follow broadcast rules.
+
+    During computation, the axes of data.shape and target_shape are checked from right to left.
+    For every axis, if it either:
+        - exists in data but not in target_shape, or
+        - is larger than 1 in data and equals 1 in target_shape,
+    data will be summed over this axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input tensor.
+
+    target_shape : Tuple[int]
+        The shape to collapse to.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor after summation.
+ """ + return cpp.collapse_sum(data, target_shape) diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 3d1c6f9f7d5b..a9d692cc0752 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -64,5 +64,9 @@ TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); }); +TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::collapse_sum(args[0], args[1]); +}); + } // namespace topi } // namespace tvm diff --git a/tests/python/topi/python/test_topi_reduce.py b/tests/python/topi/python/test_topi_reduce.py index 3c4c170d0dd9..8e45ae9a6eab 100644 --- a/tests/python/topi/python/test_topi_reduce.py +++ b/tests/python/topi/python/test_topi_reduce.py @@ -25,7 +25,7 @@ import tvm.testing import tvm.topi.testing -from tvm import te, topi +from tvm import te, topi, tir in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters( ((32,), 0, False, "argmax", "float32"), @@ -191,5 +191,53 @@ def test_complex_reduce(target, dev): tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3) +n = tir.Var("n", "int32") +m = tir.Var("m", "int32") +true_value_map = {n: 3, m: 5} + +data_shape, target_shape = tvm.testing.parameters( + ((2, 3), (3,)), + ((2, 3, 4), (2, 1, 4)), + ((2, 3, 4, 5), (3, 1, 5)), + ((2, n, 4, m), (n, 1, m)), +) + + +def _my_npy_collapse_sum(data, target_shape): + reduce_axes = [] + i = data.ndim - 1 + j = len(target_shape) - 1 + while i >= 0: + if j < 0: + reduce_axes.append(i) + elif target_shape[j] == 1 and data.shape[i] > 1: + reduce_axes.append(i) + i -= 1 + j -= 1 + return np.sum(data, tuple(reduce_axes)).reshape(target_shape) + + +def test_collapse_sum(data_shape, target_shape): + A = te.placeholder(data_shape, name="A") + B = topi.collapse_sum(A, target_shape) + s = te.create_schedule([B.op]) + + data_shape_const = [int(s) if s not in true_value_map else true_value_map[s] for s in A.shape] + target_shape_const = [ + int(s) if s not in true_value_map else true_value_map[s] for s in target_shape + ] + a_np = np.random.uniform(size=data_shape_const).astype(A.dtype) + b_np = _my_npy_collapse_sum(a_np, target_shape_const) + dev = tvm.cpu(0) + a = tvm.nd.array(a_np, dev) + B_shape_const = [int(s) if s not in true_value_map else true_value_map[s] for s in B.shape] + b = tvm.nd.array(np.zeros(B_shape_const, dtype=B.dtype), dev) + # Building with the CSE pass disabled + with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): + foo = tvm.build(s, [A, B], "llvm", name="collapse_sum") + foo(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + + if __name__ == "__main__": tvm.testing.main()