From 28c85f0dc24e41d14ebd5f03ad2a021a99290096 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 14 May 2023 17:20:19 -0400 Subject: [PATCH] [METAL] Fix vectorized select (#14846) This PR fixes the codegen for vectorized select in metal. Also enhances arithmetics to cover better constant bound. --- src/arith/const_int_bound.cc | 2 + src/target/source/codegen_metal.cc | 5 ++ src/target/source/codegen_metal.h | 1 + src/tir/transforms/flatten_buffer.cc | 2 + src/tir/transforms/lower_intrin.cc | 10 ++-- .../unittest/test_arith_const_int_bound.py | 18 ++++++++ .../unittest/test_target_codegen_metal.py | 46 ++++++++++++++----- 7 files changed, 69 insertions(+), 15 deletions(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 4c048bfa73ef..68ade3bb5400 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -177,6 +177,8 @@ class ConstIntBoundAnalyzer::Impl return Union(a, b); } + Entry VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); } + Entry VisitExpr_(const CastNode* op) final { Entry a; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 44da240dd5b0..9288c94e3df3 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -289,6 +289,11 @@ void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) } } +void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) + os << "select(" << PrintExpr(op->false_value) << ", " << PrintExpr(op->true_value) << ", " + << PrintExpr(op->condition) << ")"; +} + void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); PrintType(op->dtype, os); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 25643896093a..36be10d16363 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -50,6 +50,7 @@ class CodeGenMetal final : public CodeGenC { // print store of single element. void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor + void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 5a248dfbc311..933d5eeefb56 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -51,6 +51,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { } private: + using IRMutatorWithAnalyzer::VisitExpr; + using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt; using IRMutatorWithAnalyzer::VisitStmt_; diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 4cffe2a19d60..212ccf6e5616 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -119,7 +119,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // in terms of truncdiv using only positive operands. arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); if (const_int_bound->min_value < 0 && - const_int_bound->min_value > -(Downcast(tvm::max_value(op->a->dtype))->value)) { + const_int_bound->min_value > + -(Downcast(tvm::max_value(op->a->dtype.element_of()))->value)) { // The goal is to write floordiv(a,b) in terms of truncdiv, without using // negative operands. // @@ -150,7 +151,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // floordiv(a,b) // == floordiv(a + b*c, b) - c // == truncdiv(a + b*c, b) - c - IntImm min(op->a->dtype, const_int_bound->min_value); + IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); return truncdiv(offset_numerator, op->b) - ceildiv; @@ -214,7 +215,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // in terms of truncmod using only positive operands. arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); if (const_int_bound->min_value < 0 && - const_int_bound->min_value > -(Downcast(tvm::max_value(op->a->dtype))->value)) { + const_int_bound->min_value > + -(Downcast(tvm::max_value(op->a->dtype.element_of()))->value)) { // The goal is to write floormod(a,b) in terms of truncdiv and truncmod, // without using negative operands. // @@ -244,7 +246,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // floormod(a,b) // == floormod(a + b*c, b) // == truncmod(a + b*c, b) - IntImm min(op->a->dtype, const_int_bound->min_value); + IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); return truncmod(offset_numerator, op->b); diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py index a97345da1a29..d9ea36206b06 100644 --- a/tests/python/unittest/test_arith_const_int_bound.py +++ b/tests/python/unittest/test_arith_const_int_bound.py @@ -349,5 +349,23 @@ def test_multiple_condition(): assert bound.min_value == 0 +def test_broadcast_bound(): + analyzer = tvm.arith.Analyzer() + a = te.var("a") + analyzer.update(a, tvm.arith.ConstIntBound(0, 128)) + bound = analyzer.const_int_bound(tvm.tir.Broadcast(a, 4)) + assert bound.min_value == 0 + assert bound.max_value == 128 + + +def test_ramp_bound(): + analyzer = tvm.arith.Analyzer() + a = te.var("a") + analyzer.update(a, tvm.arith.ConstIntBound(0, 128)) + bound = analyzer.const_int_bound(tvm.tir.Ramp(a, 2, 4) + 2) + assert bound.min_value == 2 + assert bound.max_value == 128 + 2 * 3 + 2 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py index 27d0c037edf3..3b1cdb4422c5 100644 --- a/tests/python/unittest/test_target_codegen_metal.py +++ b/tests/python/unittest/test_target_codegen_metal.py @@ -18,15 +18,10 @@ from tvm import te import numpy as np -from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 -from tvm.contrib import nvcc import tvm.testing import tvm.script from tvm.script import tir as T -tx = te.thread_axis("threadIdx.x") -bx = te.thread_axis("blockIdx.x") - @tvm.testing.requires_gpu @tvm.testing.requires_metal @@ -37,9 +32,11 @@ def check_inf_nan(dev, n, value, dtype): A = te.placeholder((n,), name="A", dtype=dtype) inf_value = tvm.tir.const(value, dtype=dtype) C = te.compute((n,), lambda i: inf_value, name="C") - s = te.create_schedule(C.op) - s[C].bind(s[C].op.axis[0], tx) - fun = tvm.build(s, [A, C], target) + prim_func = te.create_prim_func([A, C]) + sch = tvm.tir.Schedule(prim_func) + (x,) = sch.get_loops(sch.get_block("C")) + sch.bind(x, "threadIdx.x") + fun = tvm.build(sch.mod, target=target) a = tvm.nd.empty((n,), A.dtype, dev) c = tvm.nd.empty((n,), A.dtype, dev) # Only need to test compiling here @@ -88,9 +85,11 @@ def test_metal_erf(): def check_erf(dev, n, dtype): A = te.placeholder((n,), name="A", dtype=dtype) C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C") - s = te.create_schedule(C.op) - s[C].bind(s[C].op.axis[0], tx) - fun = tvm.build(s, [A, C], target) + func = te.create_prim_func([A, C]) + sch = tvm.tir.Schedule(func) + (x,) = sch.get_loops(sch.get_block("C")) + sch.bind(x, "threadIdx.x") + fun = tvm.build(sch.mod, target=target) a = tvm.nd.empty((n,), A.dtype, dev) c = tvm.nd.empty((n,), A.dtype, dev) # Only need to test compiling here @@ -125,6 +124,31 @@ def main(A: T.Buffer((1, 2), "int32")): assert tuple(a_nd.numpy()[0, :]) == (0, 3) +@tvm.testing.requires_gpu +@tvm.testing.requires_metal +def test_select_vectorize(): + @tvm.script.ir_module + class IRModule: + @T.prim_func + def main(A: T.Buffer((6), "float32"), B: T.Buffer((6,), "float32")): + T.func_attr({"global_symbol": "main"}) + for i0_1 in T.thread_binding(3, thread="threadIdx.x"): + for i0_0 in T.vectorized(2): + with T.block("block"): + vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1) + B[vi0] = T.Select((vi0 % 2) == 0, A[vi0], T.float32(0)) + + target = "metal" + dev = tvm.metal() + a = np.arange(6).astype("float32") + a_nd = tvm.nd.array(a, dev) + b_nd = tvm.nd.empty((6,), "float32", dev) + f = tvm.build(IRModule, target=target) + f(a_nd, b_nd) + a.reshape(3, 2)[:, 1] = 0 + np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_ramp() test_metal_inf_nan()