|
15 | 15 | *******************************************************************************/ |
16 | 16 |
|
17 | 17 | #include "graph/backend/dnnl/kernels/gated_mlp.hpp" |
18 | | -#include "graph/backend/dnnl/kernels/large_partition.hpp" |
19 | 18 |
|
20 | 19 | #include "graph/backend/dnnl/patterns/fusions.hpp" |
21 | 20 | #include "graph/backend/dnnl/patterns/pattern_matcher_pass.hpp" |
@@ -87,7 +86,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp) |
87 | 86 | in_edges_t {in_edge(0, pre_tc, 0)}); |
88 | 87 | }) |
89 | 88 | .set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr { |
90 | | - return std::make_shared<gated_mlp_base_t>(); |
| 89 | + return std::make_shared<gated_mlp_base_t<false>>(); |
91 | 90 | }); |
92 | 91 |
|
93 | 92 | // gated mlp with swish decomposed to sigmoid and multiply. |
@@ -131,7 +130,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp_v1) |
131 | 130 | in_edges_t {in_edge(0, pre_tc, 0)}); |
132 | 131 | }) |
133 | 132 | .set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr { |
134 | | - return std::make_shared<gated_mlp_base_t>(); |
| 133 | + return std::make_shared<gated_mlp_base_t<false>>(); |
135 | 134 | }); |
136 | 135 |
|
137 | 136 | /* |
@@ -195,7 +194,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, quantized_gated_mlp) |
195 | 194 | pgraph->append_op(graph::op_kind::MatMul, fc_down_edges); |
196 | 195 | }) |
197 | 196 | .set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr { |
198 | | - return std::make_shared<larger_partition_kernel_t>(); |
| 197 | + return std::make_shared<gated_mlp_base_t<true>>(); |
199 | 198 | }); |
200 | 199 |
|
201 | 200 | // quantized gated mlp with swish decomposed to sigmoid and multiply. |
@@ -246,7 +245,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, quantized_gated_mlp_v1) |
246 | 245 | pgraph->append_op(graph::op_kind::MatMul, fc_down_edges); |
247 | 246 | }) |
248 | 247 | .set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr { |
249 | | - return std::make_shared<larger_partition_kernel_t>(); |
| 248 | + return std::make_shared<gated_mlp_base_t<true>>(); |
250 | 249 | }); |
251 | 250 |
|
252 | 251 | DNNL_BACKEND_REGISTER_PATTERN_DEF_END |
|
0 commit comments