[ROCm] use cupy for GPU-accelerated computing (microsoft#16611)
The kernel explorer has many tests that rely on numpy to verify the results of GPU kernels, which drives CPU utilization very high. This PR replaces `numpy` with `cupy` for the reference computation so that it runs on the GPU, reducing CPU utilization.

Set `KERNEL_EXPLORER_TEST_USE_CUPY=1` to enable cupy.
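A minimal sketch of how this switch is exercised (assuming cupy is installed and the script runs where the kernel explorer's `utils` module is importable; everything except `matmul` and the environment variable is illustrative):

```python
import os

# Opt in to the GPU reference path; the flag is read inside the helper at call time.
os.environ["KERNEL_EXPLORER_TEST_USE_CUPY"] = "1"

import numpy as np

from utils import matmul  # helper added by this PR in kernel_explorer/kernels/utils.py

a = np.random.rand(8, 4)   # stored transposed, as the gemm tests do
b = np.random.rand(8, 16)

# transa=True makes the helper compute a.T @ b; with the flag set, the product
# runs through cupy on the GPU and is copied back to a numpy array.
c = matmul(a, b, transa=True)
print(c.shape)  # (4, 16)
```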
PeixuanZuo authored Jul 10, 2023
1 parent 5fee3f4 commit 3b729e5
Showing 9 changed files with 52 additions and 33 deletions.
@@ -10,7 +10,7 @@
import kernel_explorer as ke
import numpy as np
import pytest
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix


def dtype_to_suffix(dtype):
@@ -31,7 +31,7 @@ def _test_batched_gemm(
np.random.seed(0)
as_ = [(np.random.rand(*a_shape) + 0.5).astype(dtype).astype("float64") for i in range(batch)]
bs = [(np.random.rand(*b_shape) + 0.5).astype(dtype).astype("float64") for i in range(batch)]
intermediate_cs = [(as_[i].T if transa else as_[i]) @ (bs[i].T if transb else bs[i]) for i in range(batch)]
intermediate_cs = [matmul(as_[i], bs[i], transa, transb) for i in range(batch)]
if alpha == 1.0 and beta == 0.0: # fast path
ref_cs = intermediate_cs
else:
@@ -11,7 +11,7 @@
import kernel_explorer as ke
import numpy as np
import pytest
from utils import dtype_to_bytes
from utils import dtype_to_bytes, fast_gelu


def get_bert_sizes():
@@ -30,12 +30,6 @@ def dtype_to_funcs(dtype):
return type_map[dtype]


def fast_gelu(x, bias):
x = x + bias
y = 0.5 * x * (1 + np.tanh(0.797885 * x + 0.035677 * x * x * x))
return y


def run_fast_gelu(x_size, bias_size, dtype, func):
np.random.seed(0)
x = np.random.rand(*x_size).astype(dtype)
@@ -10,13 +10,15 @@
import kernel_explorer as ke
import numpy as np
import pytest
from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix


def fast_gelu(x, bias):
x = x + bias
y = 0.5 * x * (1 + np.tanh(0.797885 * x + 0.035677 * x * x * x))
return y
from utils import (
dtype_to_suffix,
fast_gelu,
get_gemm_basic_sizes,
get_gemm_bert_sizes,
get_gemm_bound,
matmul,
transab_to_suffix,
)


# TODO The test method needs update.
@@ -30,7 +32,7 @@ def _test_gemmfastgelu(my_func, dtype: str, m: int, n: int, k: int, transa=False
a = (np.random.rand(*a_shape)).astype(dtype).astype("float64")
b = (np.random.rand(*b_shape)).astype(dtype).astype("float64")
bias = (np.random.rand(n)).astype(dtype)
temp_c = (a.T if transa else a) @ (b.T if transb else b)
temp_c = matmul(a, b, transa, transb)

bound = get_gemm_bound(dtype, a, b, temp_c, transa, transb, a_b_positive=True)

@@ -12,7 +12,7 @@
import kernel_explorer as ke
import numpy as np
import pytest
from utils import dtype_to_suffix, softmax
from utils import dtype_to_suffix, matmul, softmax


def multinormal_distribution(num_distribution, num_element_per_dist):
@@ -117,7 +117,7 @@ def _test_gemm_softmax_gemm_permute(
if mask_shape is not None:
attn_mask = (np.random.randint(0, 100, size=mask_shape) < 95).astype(np.int32)

pre_softmax_attn_scores = q @ np.swapaxes(k, 2, 3)
pre_softmax_attn_scores = matmul(q, np.swapaxes(k, 2, 3))
pre_softmax_attn_scores = pre_softmax_attn_scores * scale
if attn_bias is not None:
pre_softmax_attn_scores = pre_softmax_attn_scores + attn_bias
@@ -130,7 +130,7 @@ def _test_gemm_softmax_gemm_permute(
converted_mask = (1 - attn_mask.reshape(mask_shape_broadcasted)) * filter_value
pre_softmax_attn_scores = pre_softmax_attn_scores + converted_mask
attn_scores = softmax(pre_softmax_attn_scores, axis=-1)
attn = attn_scores @ v
attn = matmul(attn_scores, v)
ref = np.swapaxes(attn, 2, 1) # permute 0213

out = np.empty(out_shape, dtype=dtype)
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py
@@ -10,7 +10,7 @@
import kernel_explorer as ke
import numpy as np
import pytest
from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix


def _test_gemm(func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, alpha=1.0, beta=0.0):
@@ -22,7 +22,7 @@ def _test_gemm(func, dtype: str, transa: bool, transb: bool, m: int, n: int, k:
np.random.seed(0)
a = (np.random.rand(*a_shape) + 0.5).astype(dtype).astype("float64")
b = (np.random.rand(*b_shape) + 0.5).astype(dtype).astype("float64")
intermediate_c = (a.T if transa else a) @ (b.T if transb else b)
intermediate_c = matmul(a, b, transa, transb)
if alpha == 1.0 and beta == 0.0: # fast path
ref_c = intermediate_c
else:
@@ -11,7 +11,7 @@
import kernel_explorer as ke
import numpy as np
import pytest
from utils import dtype_to_bytes, dtype_to_suffix
from utils import dtype_to_bytes, dtype_to_suffix, standardization


def get_sd_sizes():
@@ -40,9 +40,7 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
input_x = input_x.transpose([0, 3, 1, 2])
assert c % num_groups == 0
x = input_x.reshape((n, num_groups, -1))
mean = np.mean(x, axis=-1, keepdims=True)
var = np.var(x, axis=-1, keepdims=True)
x = (x - mean) / np.sqrt(var + epsilon)
x = standardization(x, -1, epsilon)
x = x.reshape((n, c, h, w))
x = x.transpose([0, 2, 3, 1])
x = x * gamma + beta
@@ -11,7 +11,7 @@
import kernel_explorer as ke
import numpy as np
import pytest
from utils import dtype_to_bytes
from utils import dtype_to_bytes, standardization


def get_bert_sizes_test():
@@ -38,10 +38,7 @@ def dtype_to_funcs(dtype):

def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon):
val = input_x + skip + bias
x_u = np.mean(val, axis=(2,))
x_s = np.var(val, axis=(2,))
output = val - x_u[..., None]
output = output / np.sqrt(x_s + epsilon)[..., None]
output = standardization(val, 2, epsilon)
output = output * gamma + beta
return output, val

@@ -10,7 +10,7 @@
import kernel_explorer as ke
import numpy as np
import pytest
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix


def dtype_to_suffix(dtype):
@@ -31,7 +31,9 @@ def _test_strided_batched_gemm(
np.random.seed(0)
a = (np.random.rand(batch, *a_shape) + 0.5).astype(dtype).astype("float64")
b = (np.random.rand(batch, *b_shape) + 0.5).astype(dtype).astype("float64")
intermediate_c = (a.swapaxes(1, 2) if transa else a) @ (b.swapaxes(1, 2) if transb else b)
tmp_a = a.swapaxes(1, 2) if transa else a
tmp_b = b.swapaxes(1, 2) if transb else b
intermediate_c = matmul(tmp_a, tmp_b)
if alpha == 1.0 and beta == 0.0: # fast path
ref_c = intermediate_c
else:
26 changes: 26 additions & 0 deletions onnxruntime/python/tools/kernel_explorer/kernels/utils.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import os
from itertools import product

import numpy as np
@@ -100,3 +101,28 @@ def softmax(x, *, is_log_softmax=False, axis=-1):
if is_log_softmax:
return x - np.log(np.sum(np.exp(x), axis=axis, keepdims=1))
return (np.exp(x)) / np.sum(np.exp(x), axis=axis, keepdims=1)


def _matmul(a, b):
if os.getenv("KERNEL_EXPLORER_TEST_USE_CUPY", "0") == "1":
import cupy as cp

return (cp.asarray(a) @ cp.asarray(b)).get()
else:
return a @ b


def matmul(a, b, transa=False, transb=False):
return _matmul(a.T if transa else a, b.T if transb else b)


def fast_gelu(x, bias):
x = x + bias
y = 0.5 * x * (1 + np.tanh(0.797885 * x + 0.035677 * x * x * x))
return y


def standardization(x, axis, epsilon):
mean = np.mean(x, axis=axis, keepdims=True)
variance = np.var(x, axis=axis, keepdims=True)
return (x - mean) / np.sqrt(variance + epsilon)
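
For reference, the constants 0.797885 and 0.035677 in the consolidated `fast_gelu` helper above correspond to the standard tanh approximation of GELU with its coefficients multiplied out:

$$
\text{FastGelu}(x) \approx 0.5\,x\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right),
\qquad \sqrt{\tfrac{2}{\pi}} \approx 0.797885,\quad 0.044715\,\sqrt{\tfrac{2}{\pi}} \approx 0.035677.
$$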
