Skip to content

Commit

Permalink
[CUDA] Fixed the call of the min function in the schedule for cuda (a…
Browse files Browse the repository at this point in the history
…pache#14751)

* fixed the call of the minimum function in the schedule for cuda

* add test for scatter_nd

* update test only for cuda target

* fix lint

* update test

* fix lint

* apply comments
  • Loading branch information
valmat07 authored May 15, 2023
1 parent 602133e commit b4c1c38
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
fused_shape *= i

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
tdim = min(max_threads, fused_updates_dimension)

tdim = tvm.tir.min(max_threads, fused_updates_dimension)
with ib.new_scope():
bdim = ceil_div(fused_shape, tdim)
bx = te.thread_axis("blockIdx.x")
Expand Down
23 changes: 23 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,29 @@ def verify_scatter_nd(data_np, indices_np, updates_np, ref_res):
verify_scatter_nd(data, indices, updates, out)


@tvm.testing.uses_gpu
def test_scatter_nd_any_updates():
def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res):
indices_shape = (2, relay.Any())
updates_shape = (2, relay.Any())
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))

out = relay.op.scatter_nd(data, indices, updates, "add")

mod = tvm.IRModule()
mod["main"] = relay.Function([data, indices, updates], out)

check_result([data_np, indices_np, updates_np], mod, [ref_res], only_vm=True)

data = np.zeros((3, 3)).astype("int64")
indices = np.array([[1, 1], [0, 1]])
updates = np.array([[2, 2], [1, 1]])
out = np.array([[0, 0, 0], [0, 0, 0], [2, 2, 1]])
verify_scatter_nd_any_updates(data, indices, updates, out)


@tvm.testing.uses_gpu
def test_gather():
def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis):
Expand Down

0 comments on commit b4c1c38

Please sign in to comment.