Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions examples/graph/gated_mlp_int4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
// Incremental IDs used to create logical tensors and operations.
size_t id = 0;

// Intermediate data type
const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;

// dequantize for fc_gate weights
auto wei0_int4 = logical_tensor(
id++, data_type::u4, wei0_sz, layout_type::strided);
Expand All @@ -130,7 +133,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,

// fc_gate
auto src = logical_tensor(id++, dt, src_sz, layout_type::strided);
auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
auto out0 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
auto fc_gate = op(id++, op::kind::MatMul, "fc_gate");
fc_gate.add_inputs({src, wei0_dt});
fc_gate.add_outputs({out0});
Expand All @@ -151,29 +154,38 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
deq_up.add_outputs({wei1_dt});

// fc_up
auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
auto out1 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
auto fc_up = op(id++, op::kind::MatMul, "fc_up");
fc_up.add_inputs({src, wei1_dt});
fc_up.add_outputs({out1});

// activation swish: sigmoid
auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
auto out2 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid");
swi_sig.add_inputs({out0});
swi_sig.add_outputs({out2});

// activation swish: multiply
auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
auto out3 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply");
swi_mul.add_inputs({out0, out2});
swi_mul.add_outputs({out3});

// multiplication
auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
auto out4 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
auto mul = op(id++, op::kind::Multiply, "mul");
mul.add_inputs({out3, out1});
mul.add_outputs({out4});

// Downconvert the intermediate result to dt when dt differs from dt_inter.
auto out4_dt = out4;
auto typecast = op(id++, op::kind::TypeCast, "typecast");
if (dt != dt_inter) {
out4_dt = logical_tensor(id++, dt, hd_sz, layout_type::strided);
typecast.add_inputs({out4});
typecast.add_outputs({out4_dt});
}

// dequantize for fc_down weights
auto wei2_int4 = logical_tensor(
id++, data_type::u4, wei2_sz, layout_type::strided);
Expand All @@ -192,7 +204,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
// fc_down
auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided);
auto fc_down = op(id++, op::kind::MatMul, "fc_down");
fc_down.add_inputs({out4, wei2_dt});
fc_down.add_inputs({out4_dt, wei2_dt});
fc_down.add_outputs({dst});

// Construct a gated mlp graph with engine kind and operations.
Expand All @@ -205,6 +217,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
mlp.add_op(swi_sig);
mlp.add_op(swi_mul);
mlp.add_op(mul);
if (dt != dt_inter) { mlp.add_op(typecast); }
mlp.add_op(deq_down);
mlp.add_op(fc_down);
mlp.finalize();
Expand Down
30 changes: 30 additions & 0 deletions src/graph/backend/dnnl/executables/gated_mlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,39 @@ arg_indices_t gated_mlp_executable_t::get_arg_indices(const op_t *op) {
args.insert({DNNL_ARG_WEIGHTS_UP, {indices_t::type_t::input, idx++}});
args.insert({DNNL_ARG_WEIGHTS_DOWN, {indices_t::type_t::input, idx++}});

// Optional runtime scales and zero points for the gate/up/down weights; each
// one that is present takes the next input index.
const auto &fusion_info = op->has_attr(op_attr::fusion_info)
? op->get_attr<fusion_info_t>(op_attr::fusion_info)
: fusion_info_t();
if (fusion_info.with_runtime_scales(true, DNNL_ARG_WEIGHTS_GATE)) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS_GATE,
{indices_t::type_t::input, idx++}});
}
if (fusion_info.with_runtime_zero_points(true, DNNL_ARG_WEIGHTS_GATE)) {
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS_GATE,
{indices_t::type_t::input, idx++}});
}
if (fusion_info.with_runtime_scales(true, DNNL_ARG_WEIGHTS_UP)) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS_UP,
{indices_t::type_t::input, idx++}});
}
if (fusion_info.with_runtime_zero_points(true, DNNL_ARG_WEIGHTS_UP)) {
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS_UP,
{indices_t::type_t::input, idx++}});
}
if (fusion_info.with_runtime_scales(true, DNNL_ARG_WEIGHTS_DOWN)) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS_DOWN,
{indices_t::type_t::input, idx++}});
}
if (fusion_info.with_runtime_zero_points(true, DNNL_ARG_WEIGHTS_DOWN)) {
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS_DOWN,
{indices_t::type_t::input, idx++}});
}

// outputs
args.insert({DNNL_ARG_DST, {indices_t::type_t::output, 0}});
args.insert({DNNL_ARG_SCRATCHPAD, {indices_t::type_t::output, 1}});

return args;
}

Expand Down
4 changes: 3 additions & 1 deletion src/graph/backend/dnnl/kernels/gated_mlp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace impl {
namespace graph {
namespace dnnl_impl {

template <bool quantized = false>
struct gated_mlp_base_t : public kernel_base_t {
private:
std::shared_ptr<kernel_base_t> kernel;
Expand All @@ -53,7 +54,8 @@ struct gated_mlp_base_t : public kernel_base_t {
status_t ret = status::unimplemented;

if (enable_ukernel) {
kernel = std::make_shared<gated_mlp_primitive_kernel_t>();
kernel = std::make_shared<
gated_mlp_primitive_kernel_t<quantized>>();
ret = kernel->compile_impl(part, engine, inputs, outputs);
}

Expand Down
34 changes: 26 additions & 8 deletions src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ namespace impl {
namespace graph {
namespace dnnl_impl {

status_t gated_mlp_primitive_kernel_t::compile_impl(
template <bool quantized>
status_t gated_mlp_primitive_kernel_t<quantized>::compile_impl(
const dnnl_partition_impl_t *part, const engine_t *eng,
const std::vector<logical_tensor_t> &inputs,
const std::vector<logical_tensor_t> &outputs) {
Expand All @@ -62,6 +63,16 @@ status_t gated_mlp_primitive_kernel_t::compile_impl(
pass_pipeline_t pipeline = pass_pipeline_t(vis);

BACKEND_DNNL_ADD_PASS(pipeline, lower_down);

if (quantized) {
BACKEND_DNNL_ADD_PASS(pipeline, fuse_typecast_to_matmul_or_conv);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_typecast_to_predecessor);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_scales);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_zero_points);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_scales);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_zero_points);
}

pipeline.reset_visualize_arg(true, false);
BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_gated_mlp);
Expand Down Expand Up @@ -97,7 +108,8 @@ status_t gated_mlp_primitive_kernel_t::compile_impl(
return status::success;
}

void gated_mlp_primitive_kernel_t::prepare_args_set(
template <bool quantized>
void gated_mlp_primitive_kernel_t<quantized>::prepare_args_set(
const execution_args_set_t *res, const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs, const scratchpad_t &scratchpad) {
// update the data of partition in/outputs args
Expand Down Expand Up @@ -129,8 +141,9 @@ void gated_mlp_primitive_kernel_t::prepare_args_set(
}
}

status_t gated_mlp_primitive_kernel_t::execute_impl(const stream_t *stream,
const std::vector<tensor_t> &inputs,
template <bool quantized>
status_t gated_mlp_primitive_kernel_t<quantized>::execute_impl(
const stream_t *stream, const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs) {
dnnl::stream p_stream = make_dnnl_stream(p_engine_, *stream);

Expand All @@ -151,8 +164,9 @@ status_t gated_mlp_primitive_kernel_t::execute_impl(const stream_t *stream,
}

#ifdef DNNL_WITH_SYCL
status_t gated_mlp_primitive_kernel_t::sycl_execute_impl(const stream_t *stream,
const std::vector<tensor_t> &inputs,
template <bool quantized>
status_t gated_mlp_primitive_kernel_t<quantized>::sycl_execute_impl(
const stream_t *stream, const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs,
const std::vector<::sycl::event> &sycl_deps, ::sycl::event *ret_event) {
// gated_mlp_primitive_kernel_t only supports Intel GPU.
Expand Down Expand Up @@ -188,8 +202,9 @@ status_t gated_mlp_primitive_kernel_t::sycl_execute_impl(const stream_t *stream,
#endif

#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
status_t gated_mlp_primitive_kernel_t::ocl_execute_impl(const stream_t *stream,
const std::vector<tensor_t> &inputs,
template <bool quantized>
status_t gated_mlp_primitive_kernel_t<quantized>::ocl_execute_impl(
const stream_t *stream, const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs,
const std::vector<cl_event> &ocl_deps, cl_event *ret_event) {
auto deps = ocl_deps;
Expand Down Expand Up @@ -220,6 +235,9 @@ status_t gated_mlp_primitive_kernel_t::ocl_execute_impl(const stream_t *stream,
}
#endif

// Explicitly instantiate both variants of the kernel. The member functions of
// the template are defined in this translation unit, so the quantized and
// non-quantized specializations are emitted here.
template struct gated_mlp_primitive_kernel_t<true>;
template struct gated_mlp_primitive_kernel_t<false>;

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
Expand Down
1 change: 1 addition & 0 deletions src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace impl {
namespace graph {
namespace dnnl_impl {

template <bool quantized>
struct gated_mlp_primitive_kernel_t : public kernel_base_t {
private:
allocator_t *g_alloc_ = nullptr;
Expand Down
69 changes: 69 additions & 0 deletions src/graph/backend/dnnl/passes/transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4584,6 +4584,10 @@ status_t fuse_sdpa(std::shared_ptr<subgraph_t> &sg) {
return status::success;
}

// Readable aliases for the three gated MLP weight arguments, mapped onto the
// generic DNNL_ARG_WEIGHTS_0/1/2 identifiers (gate, up and down respectively).
// NOTE(review): these are plain unguarded macros in a .cpp file; if the same
// names are also defined in a header included here, confirm the definitions
// match to avoid redefinition warnings.
#define DNNL_ARG_WEIGHTS_GATE DNNL_ARG_WEIGHTS_0
#define DNNL_ARG_WEIGHTS_UP DNNL_ARG_WEIGHTS_1
#define DNNL_ARG_WEIGHTS_DOWN DNNL_ARG_WEIGHTS_2

// The pass is called against a gated mlp subgraph matched by the gated mlp
// patterns. Hence we have the basic assumptions for the topology which
// simplifies the pass logic below.
Expand Down Expand Up @@ -4656,11 +4660,64 @@ status_t fuse_gated_mlp(std::shared_ptr<subgraph_t> &sg) {
return status::unimplemented;
}

fusion_info_t fusion_info;
if (gate->has_attr(op_attr::fusion_info)) {
auto gate_fusion_info
= gate->get_attr<fusion_info_t>(op_attr::fusion_info);
if (gate_fusion_info.get_mutable_scales(true, 1)) {
fusion_info.set_runtime_scales(
gate_fusion_info.get_mutable_scales(true, 1)
->shared_from_this(),
true, DNNL_ARG_WEIGHTS_GATE);
}
if (gate_fusion_info.with_runtime_zero_points(true, 1)) {
fusion_info.set_zero_points(
gate_fusion_info.get_mutable_zero_points(true, 1)
->shared_from_this(),
true, DNNL_ARG_WEIGHTS_GATE);
}
}

if (up->has_attr(op_attr::fusion_info)) {
auto up_fusion_info = up->get_attr<fusion_info_t>(op_attr::fusion_info);
if (up_fusion_info.get_mutable_scales(true, 1)) {
fusion_info.set_runtime_scales(
up_fusion_info.get_mutable_scales(true, 1)
->shared_from_this(),
true, DNNL_ARG_WEIGHTS_UP);
}
if (up_fusion_info.with_runtime_zero_points(true, 1)) {
fusion_info.set_zero_points(
up_fusion_info.get_mutable_zero_points(true, 1)
->shared_from_this(),
true, DNNL_ARG_WEIGHTS_UP);
}
}

if (down->has_attr(op_attr::fusion_info)) {
auto down_fusion_info
= down->get_attr<fusion_info_t>(op_attr::fusion_info);
if (down_fusion_info.get_mutable_scales(true, 1)) {
fusion_info.set_runtime_scales(
down_fusion_info.get_mutable_scales(true, 1)
->shared_from_this(),
true, DNNL_ARG_WEIGHTS_DOWN);
}
if (down_fusion_info.with_runtime_zero_points(true, 1)) {
fusion_info.set_zero_points(
down_fusion_info.get_mutable_zero_points(true, 1)
->shared_from_this(),
true, DNNL_ARG_WEIGHTS_DOWN);
}
}

subgraph_rewriter_t rewriter(sg);
op_ptr gated_mlp_op = std::make_shared<op_t>(op_kind::_gated_mlp);
gated_mlp_op->set_attr<int64_t>(
op_attr::alg_kind, static_cast<int64_t>(act_algo));

gated_mlp_op->set_attr<fusion_info_t>(op_attr::fusion_info, fusion_info);

// connect inputs and outputs
auto src_val = gate->get_input_value(0);
auto wei0_val = gate->get_input_value(1);
Expand All @@ -4674,6 +4731,18 @@ status_t fuse_gated_mlp(std::shared_ptr<subgraph_t> &sg) {
gated_mlp_op->connect_input(1, wei0_val);
gated_mlp_op->connect_input(2, wei1_val);
gated_mlp_op->connect_input(3, wei2_val);

size_t input_idx = 4;
// Move every extra matmul input (index 2 and above, i.e. the quantization
// parameters) onto the fused op, visiting gate, up, then down.
for (const auto &matmul : {gate, up, down}) {
auto inputs = matmul->get_input_values();
for (size_t idx = 2; idx < inputs.size(); ++idx) {
const auto &qparam_val = inputs[idx];
qparam_val->remove_consumer(*matmul, idx);
gated_mlp_op->connect_input(input_idx++, qparam_val);
}
}

auto dst_val = down->get_output_value(0);
dst_val->set_producer(*gated_mlp_op);
gated_mlp_op->add_output(dst_val);
Expand Down
Loading
Loading