
Commit 133e493

[OneDNN] onednn w4a16 int4 (#49)
* add w4a16_int4
  Signed-off-by: Zhu, Zufang <[email protected]>
* fix rebase error
  Signed-off-by: Zhu, Zufang <[email protected]>
* add ut for gemm kernel
  Signed-off-by: Zhu, Zufang <[email protected]>
* remove error comment
  Signed-off-by: Zhu, Zufang <[email protected]>
---------
Signed-off-by: Zhu, Zufang <[email protected]>
1 parent 08312b7 commit 133e493

File tree

15 files changed: +739 -821 lines changed

csrc/xpu/onednn/fp8_gemm_w8a16.cpp

Lines changed: 0 additions & 54 deletions
This file was deleted.

csrc/xpu/onednn/fp8_gemm_w8a16.h

Lines changed: 5 additions & 44 deletions
@@ -13,12 +13,11 @@ using GpuStreamManager = at::native::onednn::GpuStreamManager;
 using GpuEngineManager = at::native::onednn::GpuEngineManager;

 static inline void dnnl_matmul_w8a16_fp8(
-    torch::Tensor& result, const torch::Tensor& mat1, const torch::Tensor& mat2,
+    torch::Tensor& result,      // dst, [b, m, n]
+    const torch::Tensor& mat1,  // src, [b, m, k]
+    const torch::Tensor& mat2,  // quantized weight, [k, n] transpose
     bool trans_b, const std::optional<torch::Tensor>& bias,
     const torch::Tensor& m2_sc, const int64_t group_size = 0) {
-  TORCH_CHECK(mat2.scalar_type() == at::ScalarType::Float8_e5m2 ||
-                  mat2.scalar_type() == at::ScalarType::Float8_e4m3fn,
-              "weight must be f8_e5m2 or f8_e4m3fn for fp8 matmul");
   auto src_sz = mat1.sizes();
   auto o_sz = result.sizes();

@@ -44,45 +43,7 @@ static inline void dnnl_matmul_w8a16_fp8(
   }

   // get bias type
-  bias_shape_t bias_shape;
-  bias_data_type_t bias_dtype;
-  if (bias.has_value() && bias.value().defined()) {
-    auto& b = bias.value();
-    const auto nuelm = b.numel();
-    if (nuelm == 1) {
-      bias_shape = bias_shape_t::scalar;
-    } else if (nuelm == m * n) {
-      bias_shape = bias_shape_t::mn;
-    } else if (b.size(b.dim() - 1) == n && nuelm == n) {
-      bias_shape = bias_shape_t::n;
-    } else if (b.size(b.dim() - 1) == 1 && nuelm == m) {
-      bias_shape = bias_shape_t::m;
-    } else if (nuelm == 0) {
-      bias_shape = bias_shape_t::none;
-    } else {
-      TORCH_CHECK(0, "unsupported bias dim in matmul ...", b.sizes());
-    }
-
-    switch (b.scalar_type()) {
-      case at::ScalarType::Float:
-        bias_dtype = bias_data_type_t::f32;
-        break;
-      case at::ScalarType::BFloat16:
-        bias_dtype = bias_data_type_t::bf16;
-        break;
-      case at::ScalarType::Half:
-        bias_dtype = bias_data_type_t::f16;
-        break;
-      default:
-        TORCH_CHECK(false, "Unsupported data type for bias in fp8 matmul: ",
-                    b.scalar_type());
-    }
-  } else {
-    bias_shape = bias_shape_t::none;
-    bias_dtype = bias_data_type_t::none;
-  }
-
-  bias_type_t b_type = make_bias_type(bias_shape, bias_dtype);
+  bias_type_t b_type = get_bias_type(bias, m, n);

   trans_type_t tt = trans_type_t::nn;
   if (trans_b) {
@@ -137,7 +98,7 @@ static inline void dnnl_matmul_w8a16_fp8(
   arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr());
   arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr());
   arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr());
-  if (bias_shape != bias_shape_t::none) {
+  if (get_shape(b_type) != bias_shape_t::none) {
     arg_handles.emplace_back(DNNL_ARG_BIAS, bias.value().data_ptr());
   }
   int scratchpad_size = matmul_ext.get_scratchpad_size();
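Note: the inline bias-shape/dtype detection removed above is replaced by a single call to get_bias_type(bias, m, n), which the new int4 path below reuses. The helper itself is not shown in this diff; assuming it simply folds the deleted logic into the shared onednn_ext.h header (the file location and exact error strings are assumptions), a minimal sketch would be:

static inline bias_type_t get_bias_type(
    const std::optional<torch::Tensor>& bias, const int m, const int n) {
  // Mirrors the logic deleted from dnnl_matmul_w8a16_fp8 above (sketch only).
  bias_shape_t bias_shape;
  bias_data_type_t bias_dtype;
  if (bias.has_value() && bias.value().defined()) {
    auto& b = bias.value();
    const auto nuelm = b.numel();
    if (nuelm == 1) {
      bias_shape = bias_shape_t::scalar;
    } else if (nuelm == m * n) {
      bias_shape = bias_shape_t::mn;
    } else if (b.size(b.dim() - 1) == n && nuelm == n) {
      bias_shape = bias_shape_t::n;
    } else if (b.size(b.dim() - 1) == 1 && nuelm == m) {
      bias_shape = bias_shape_t::m;
    } else if (nuelm == 0) {
      bias_shape = bias_shape_t::none;
    } else {
      TORCH_CHECK(0, "unsupported bias dim in matmul ...", b.sizes());
    }
    switch (b.scalar_type()) {
      case at::ScalarType::Float:
        bias_dtype = bias_data_type_t::f32;
        break;
      case at::ScalarType::BFloat16:
        bias_dtype = bias_data_type_t::bf16;
        break;
      case at::ScalarType::Half:
        bias_dtype = bias_data_type_t::f16;
        break;
      default:
        TORCH_CHECK(false, "Unsupported bias data type in matmul: ",
                    b.scalar_type());
    }
  } else {
    bias_shape = bias_shape_t::none;
    bias_dtype = bias_data_type_t::none;
  }
  return make_bias_type(bias_shape, bias_dtype);
}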

csrc/xpu/onednn/int4_gemm_w4a16.h

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
+#pragma once
+
+#include <c10/xpu/XPUStream.h>
+#include <dnnl.hpp>
+#include <torch/torch.h>
+
+#include "onednn_ext.h"
+
+namespace oneDNN {
+
+using trans_type_t = at::native::onednn::trans_type_t;
+using GpuStreamManager = at::native::onednn::GpuStreamManager;
+using GpuEngineManager = at::native::onednn::GpuEngineManager;
+
+static inline void dnnl_matmul_w4a16_int4(
+    torch::Tensor& result,      // dst, [b, m, n]
+    const torch::Tensor& mat1,  // src, [b, m, k]
+    const torch::Tensor& mat2,  // quantized weight, [k/8, n] transpose
+    bool trans_b, const std::optional<torch::Tensor>& bias,
+    const torch::Tensor& scale,  // [k/group_size, n]
+    const torch::Tensor& zp,     // [k/group_size, n/8]
+    int64_t group_size) {
+  auto src_sz = mat1.sizes();
+  auto o_sz = result.sizes();
+
+  const int m = std::reduce(src_sz.begin(), src_sz.end() - 1, 1,
+                            std::multiplies<int64_t>());
+  const int n = o_sz.back();  // presume channel last format
+  const int k = *(src_sz.end() - 1);
+
+  // get joint dtypes
+  joint_dtypes_t jd;
+  auto in_dtype = mat1.scalar_type();
+  if (in_dtype == at::ScalarType::Half) {
+    jd = joint_dtypes_t::f16_int4;
+  } else if (in_dtype == at::ScalarType::BFloat16) {
+    jd = joint_dtypes_t::bf16_int4;
+  } else if (in_dtype == at::ScalarType::Char) {
+    jd = joint_dtypes_t::s8_int4;
+  } else if (in_dtype == at::ScalarType::Byte) {
+    jd = joint_dtypes_t::u8_int4;
+  } else {
+    TORCH_INTERNAL_ASSERT(false,
+                          "Unsupported data type for int4 matmul: ", in_dtype);
+  }
+
+  // get bias type
+  bias_type_t b_type = get_bias_type(bias, m, n);
+
+  trans_type_t tt = trans_type_t::nn;
+  if (trans_b) {
+    // transpose mat2
+    tt = trans_type_t::nt;
+  }
+
+  // get lda ldb and ldc
+  auto mat1_strides = mat1.strides();
+  int64_t leading_dim = -1;
+  if (mat1.dim() == 2) {
+    leading_dim = 0;
+  } else if (mat1.dim() == 3) {
+    leading_dim = mat1_strides[0] < mat1_strides[1] ? 0 : 1;
+  } else {
+    TORCH_CHECK(false,
+                "Unsupported input dimension for int4 matmul: ", mat1.dim());
+  }
+  int64_t lda = mat1_strides[leading_dim];
+  int64_t ldb = mat2.strides()[mat2.dim() - 1] == 1
+                    ? mat2.strides()[mat2.dim() - 2] * 8
+                    : mat2.strides()[mat2.dim() - 1] * 8;  // for int4 matmul
+  int64_t ldc = result.strides()[leading_dim];
+
+  auto f_attr = [&](primitive_attr& pattr) {
+    pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+    pattr.set_scales(DNNL_ARG_WEIGHTS,
+                     /* mask */ (1 << 0) + (1 << 1), {group_size, 1},
+                     get_onednn_dtype(scale));
+    if (zp.dim() == 1) {
+      pattr.set_zero_points(DNNL_ARG_WEIGHTS,
+                            /* mask */ 0, {}, memory::data_type::s8);
+    } else {
+      pattr.set_zero_points(DNNL_ARG_WEIGHTS,
+                            /* mask */ (1 << 0) + (1 << 1), {group_size, 1},
+                            memory::data_type::u4);
+    }
+    pattr.set_fpmath_mode(dnnl::fpmath_mode::f16, true);
+    if (in_dtype == at::ScalarType::BFloat16) {
+      pattr.set_fpmath_mode(dnnl::fpmath_mode::bf16, true);
+    } else if (in_dtype == at::ScalarType::Half) {
+      pattr.set_fpmath_mode(dnnl::fpmath_mode::f16, true);
+    } else {
+      TORCH_INTERNAL_ASSERT(
+          false, "Unsupported data type for int4 matmul: ", in_dtype);
+    }
+  };
+
+  // ************************************************************
+  // get device, engine, stream
+  const int dev_id = c10::xpu::getCurrentXPUStream().device_index();
+  at::Device curDevice = at::Device(at::kXPU, dev_id);
+  auto engine = GpuEngineManager::Instance().get_engine(curDevice);
+  int64_t zp_group_size = zp.dim() == 1 ? 1 : group_size;
+  auto& matmul_ext = matmul_primitive_create_and_cache(
+      jd, tt, b_type, m, n, k, lda, ldb, ldc, dev_id, f_attr, group_size,
+      zp_group_size);
+
+  int arg_off = 0;
+  // set scale and zero point for matmul args
+  matmul_ext.set_attribute(arg_off++, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
+                           scale.data_ptr(), [&]() {
+                             return at::native::onednn::make_onednn_memory(
+                                 get_onednn_md(scale), engine,
+                                 scale.data_ptr());
+                           });
+
+  if (zp.dim() == 1) {
+    // set zp_md for symmetric quantization
+    matmul_ext.set_attribute(arg_off++,
+                             DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS,
+                             zp.data_ptr(), [&]() {
+                               return at::native::onednn::make_onednn_memory(
+                                   get_onednn_md(zp), engine, zp.data_ptr());
+                             });
+  } else {
+    // set zp_md for asymmetric quantization
+    matmul_ext.set_attribute(
+        arg_off++, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp.data_ptr(),
+        [&]() {
+          auto num_groups = k / group_size;
+          dnnl::memory zp_B_u4_m(
+              {{num_groups, n}, memory::data_type::u4, {n, 1}}, engine,
+              zp.data_ptr());
+          return zp_B_u4_m;
+        });
+  }
+
+  // set general args
+  std::vector<std::pair<int, void*>> arg_handles;
+  arg_handles.reserve(8);
+
+  arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr());
+  arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr());
+  arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr());
+  if (get_shape(b_type) != bias_shape_t::none) {
+    arg_handles.emplace_back(DNNL_ARG_BIAS, bias.value().data_ptr());
+  }
+
+  int scratchpad_size = matmul_ext.get_scratchpad_size();
+  torch::Tensor scratchpad_tensor = at::empty(
+      {scratchpad_size}, mat1.options().dtype(at::kByte), c10::nullopt);
+  arg_handles.emplace_back(DNNL_ARG_SCRATCHPAD, scratchpad_tensor.data_ptr());
+
+  auto& strm = GpuStreamManager::Instance().get_stream();
+  matmul_ext.execute(strm, engine, std::move(arg_handles), arg_off);
+}
+}  // namespace oneDNN
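Note: for orientation only, a minimal sketch of how this new entry point might be called from C++. The shapes follow the parameter comments above; the packed-weight dtype (kByte), the 1-D zero-point tensor (symmetric path), trans_b = false, and the concrete m/n/k/group_size values are assumptions for illustration, not part of this commit.

#include <torch/torch.h>
#include "int4_gemm_w4a16.h"  // assumed relative include path

void example_w4a16_int4() {
  const int64_t m = 32, k = 4096, n = 4096, group_size = 128;
  auto f16 = torch::TensorOptions().dtype(torch::kHalf).device(at::kXPU);
  auto u8 = torch::TensorOptions().dtype(torch::kByte).device(at::kXPU);
  auto s8 = torch::TensorOptions().dtype(torch::kChar).device(at::kXPU);

  torch::Tensor src = torch::randn({m, k}, f16);                  // activations, [m, k]
  torch::Tensor weight = torch::randint(0, 256, {k / 8, n}, u8);  // packed int4 weight, [k/8, n] (dtype assumed)
  torch::Tensor scale = torch::rand({k / group_size, n}, f16);    // per-group scales
  torch::Tensor zp = torch::zeros({1}, s8);                       // 1-D zp selects the symmetric path
  torch::Tensor dst = torch::empty({m, n}, f16);

  oneDNN::dnnl_matmul_w4a16_int4(dst, src, weight, /*trans_b=*/false,
                                 /*bias=*/std::nullopt, scale, zp, group_size);
}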
