Skip to content

Commit

Permalink
[METAL] Fix vectorized select (apache#14846)
Browse files Browse the repository at this point in the history
This PR fixes the codegen for vectorized select in metal.
Also enhances arithmetics to cover better constant bound.
  • Loading branch information
tqchen authored May 14, 2023
1 parent 9f0c642 commit 28c85f0
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class ConstIntBoundAnalyzer::Impl
return Union(a, b);
}

Entry VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); }

Entry VisitExpr_(const CastNode* op) final {
Entry a;

Expand Down
5 changes: 5 additions & 0 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@ void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os)
}
}

void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
os << "select(" << PrintExpr(op->false_value) << ", " << PrintExpr(op->true_value) << ", "
<< PrintExpr(op->condition) << ")";
}

void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
PrintType(op->dtype, os);
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class CodeGenMetal final : public CodeGenC {
// print store of single element.
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
Expand Down
2 changes: 2 additions & 0 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
}

private:
using IRMutatorWithAnalyzer::VisitExpr;
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt;
using IRMutatorWithAnalyzer::VisitStmt_;

Expand Down
10 changes: 6 additions & 4 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// in terms of truncdiv using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value < 0 &&
const_int_bound->min_value > -(Downcast<IntImm>(tvm::max_value(op->a->dtype))->value)) {
const_int_bound->min_value >
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
// The goal is to write floordiv(a,b) in terms of truncdiv, without using
// negative operands.
//
Expand Down Expand Up @@ -150,7 +151,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// floordiv(a,b)
// == floordiv(a + b*c, b) - c
// == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype, const_int_bound->min_value);
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
Expand Down Expand Up @@ -214,7 +215,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// in terms of truncmod using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value < 0 &&
const_int_bound->min_value > -(Downcast<IntImm>(tvm::max_value(op->a->dtype))->value)) {
const_int_bound->min_value >
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
// The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
// without using negative operands.
//
Expand Down Expand Up @@ -244,7 +246,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// floormod(a,b)
// == floormod(a + b*c, b)
// == truncmod(a + b*c, b)
IntImm min(op->a->dtype, const_int_bound->min_value);
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncmod(offset_numerator, op->b);
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,5 +349,23 @@ def test_multiple_condition():
assert bound.min_value == 0


def test_broadcast_bound():
analyzer = tvm.arith.Analyzer()
a = te.var("a")
analyzer.update(a, tvm.arith.ConstIntBound(0, 128))
bound = analyzer.const_int_bound(tvm.tir.Broadcast(a, 4))
assert bound.min_value == 0
assert bound.max_value == 128


def test_ramp_bound():
analyzer = tvm.arith.Analyzer()
a = te.var("a")
analyzer.update(a, tvm.arith.ConstIntBound(0, 128))
bound = analyzer.const_int_bound(tvm.tir.Ramp(a, 2, 4) + 2)
assert bound.min_value == 2
assert bound.max_value == 128 + 2 * 3 + 2


if __name__ == "__main__":
tvm.testing.main()
46 changes: 35 additions & 11 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,10 @@
from tvm import te
import numpy as np

from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing
import tvm.script
from tvm.script import tir as T

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
Expand All @@ -37,9 +32,11 @@ def check_inf_nan(dev, n, value, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
inf_value = tvm.tir.const(value, dtype=dtype)
C = te.compute((n,), lambda i: inf_value, name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tx)
fun = tvm.build(s, [A, C], target)
prim_func = te.create_prim_func([A, C])
sch = tvm.tir.Schedule(prim_func)
(x,) = sch.get_loops(sch.get_block("C"))
sch.bind(x, "threadIdx.x")
fun = tvm.build(sch.mod, target=target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
Expand Down Expand Up @@ -88,9 +85,11 @@ def test_metal_erf():
def check_erf(dev, n, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tx)
fun = tvm.build(s, [A, C], target)
func = te.create_prim_func([A, C])
sch = tvm.tir.Schedule(func)
(x,) = sch.get_loops(sch.get_block("C"))
sch.bind(x, "threadIdx.x")
fun = tvm.build(sch.mod, target=target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
Expand Down Expand Up @@ -125,6 +124,31 @@ def main(A: T.Buffer((1, 2), "int32")):
assert tuple(a_nd.numpy()[0, :]) == (0, 3)


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_select_vectorize():
@tvm.script.ir_module
class IRModule:
@T.prim_func
def main(A: T.Buffer((6), "float32"), B: T.Buffer((6,), "float32")):
T.func_attr({"global_symbol": "main"})
for i0_1 in T.thread_binding(3, thread="threadIdx.x"):
for i0_0 in T.vectorized(2):
with T.block("block"):
vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1)
B[vi0] = T.Select((vi0 % 2) == 0, A[vi0], T.float32(0))

target = "metal"
dev = tvm.metal()
a = np.arange(6).astype("float32")
a_nd = tvm.nd.array(a, dev)
b_nd = tvm.nd.empty((6,), "float32", dev)
f = tvm.build(IRModule, target=target)
f(a_nd, b_nd)
a.reshape(3, 2)[:, 1] = 0
np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
test_ramp()
test_metal_inf_nan()
Expand Down

0 comments on commit 28c85f0

Please sign in to comment.