Skip to content

Commit e437944

Browse files
committed
graph: backend: dnnl: patterns: enable quantized gated mlp kernels
1 parent ff4faea commit e437944

1 file changed

Lines changed: 4 additions & 5 deletions

File tree

  • src/graph/backend/dnnl/patterns

src/graph/backend/dnnl/patterns/mlp.cpp

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
*******************************************************************************/
1616

1717
#include "graph/backend/dnnl/kernels/gated_mlp.hpp"
18-
#include "graph/backend/dnnl/kernels/large_partition.hpp"
1918

2019
#include "graph/backend/dnnl/patterns/fusions.hpp"
2120
#include "graph/backend/dnnl/patterns/pattern_matcher_pass.hpp"
@@ -87,7 +86,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp)
8786
in_edges_t {in_edge(0, pre_tc, 0)});
8887
})
8988
.set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr {
90-
return std::make_shared<gated_mlp_base_t>();
89+
return std::make_shared<gated_mlp_base_t<false>>();
9190
});
9291

9392
// gated mlp with swish decomposed to sigmoid and multiply.
@@ -131,7 +130,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp_v1)
131130
in_edges_t {in_edge(0, pre_tc, 0)});
132131
})
133132
.set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr {
134-
return std::make_shared<gated_mlp_base_t>();
133+
return std::make_shared<gated_mlp_base_t<false>>();
135134
});
136135

137136
/*
@@ -195,7 +194,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, quantized_gated_mlp)
195194
pgraph->append_op(graph::op_kind::MatMul, fc_down_edges);
196195
})
197196
.set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr {
198-
return std::make_shared<larger_partition_kernel_t>();
197+
return std::make_shared<gated_mlp_base_t<true>>();
199198
});
200199

201200
// quantized gated mlp with swish decomposed to sigmoid and multiply.
@@ -246,7 +245,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, quantized_gated_mlp_v1)
246245
pgraph->append_op(graph::op_kind::MatMul, fc_down_edges);
247246
})
248247
.set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr {
249-
return std::make_shared<larger_partition_kernel_t>();
248+
return std::make_shared<gated_mlp_base_t<true>>();
250249
});
251250

252251
DNNL_BACKEND_REGISTER_PATTERN_DEF_END

0 commit comments

Comments
 (0)