@@ -124,7 +124,10 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
124124 auto rewrite = IRMatcher::rewriter (IRMatcher::h_add (value, lanes), op->type );
125125 if (rewrite (h_add (x * broadcast (y, arg_lanes), lanes), h_add (x, lanes) * broadcast (y, lanes)) ||
126126 rewrite (h_add (broadcast (x, arg_lanes) * y, lanes), h_add (y, lanes) * broadcast (x, lanes)) ||
127- rewrite (h_add (broadcast (x, arg_lanes), lanes), broadcast (x * factor, lanes))) {
127+ rewrite (h_add (broadcast (x, arg_lanes), lanes), broadcast (x * factor, lanes)) ||
128+ rewrite (h_add (broadcast (x, c0), lanes), broadcast (h_add (x, lanes / c0), c0), lanes % c0 == 0 ) ||
129+ rewrite (h_add (broadcast (x, c0), lanes), broadcast (h_add (x, 1 ) * (c0 / lanes), lanes), c0 % lanes == 0 ) ||
130+ false ) {
128131 return mutate (rewrite.result , info);
129132 }
130133 break ;
@@ -136,8 +139,9 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
136139 rewrite (h_min (max (x, broadcast (y, arg_lanes)), lanes), max (h_min (x, lanes), broadcast (y, lanes))) ||
137140 rewrite (h_min (max (broadcast (x, arg_lanes), y), lanes), max (h_min (y, lanes), broadcast (x, lanes))) ||
138141 rewrite (h_min (broadcast (x, arg_lanes), lanes), broadcast (x, lanes)) ||
139- rewrite (h_min (broadcast (x, c0), lanes), h_min (x, lanes), factor % c0 == 0 ) ||
140- rewrite (h_min (ramp (x, y, arg_lanes), lanes), x + min (y * (arg_lanes - 1 ), 0 )) ||
142+ rewrite (h_min (broadcast (x, c0), 1 ), h_min (x, 1 )) ||
143+ rewrite (h_min (broadcast (x, c0), lanes), broadcast (h_min (x, lanes / c0), c0), lanes % c0 == 0 ) ||
144+ (lanes == 1 && rewrite (h_min (ramp (x, y, arg_lanes), lanes), x + min (y * (arg_lanes - 1 ), 0 ))) ||
141145 false ) {
142146 return mutate (rewrite.result , info);
143147 }
@@ -150,8 +154,9 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
150154 rewrite (h_max (max (x, broadcast (y, arg_lanes)), lanes), max (h_max (x, lanes), broadcast (y, lanes))) ||
151155 rewrite (h_max (max (broadcast (x, arg_lanes), y), lanes), max (h_max (y, lanes), broadcast (x, lanes))) ||
152156 rewrite (h_max (broadcast (x, arg_lanes), lanes), broadcast (x, lanes)) ||
153- rewrite (h_max (broadcast (x, c0), lanes), h_max (x, lanes), factor % c0 == 0 ) ||
154- rewrite (h_max (ramp (x, y, arg_lanes), lanes), x + max (y * (arg_lanes - 1 ), 0 )) ||
157+ rewrite (h_max (broadcast (x, c0), 1 ), h_max (x, 1 )) ||
158+ rewrite (h_max (broadcast (x, c0), lanes), broadcast (h_max (x, lanes / c0), c0), lanes % c0 == 0 ) ||
159+ (lanes == 1 && rewrite (h_max (ramp (x, y, arg_lanes), lanes), x + max (y * (arg_lanes - 1 ), 0 ))) ||
155160 false ) {
156161 return mutate (rewrite.result , info);
157162 }
@@ -164,15 +169,16 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
164169 rewrite (h_and (x && broadcast (y, arg_lanes), lanes), h_and (x, lanes) && broadcast (y, lanes)) ||
165170 rewrite (h_and (broadcast (x, arg_lanes) && y, lanes), h_and (y, lanes) && broadcast (x, lanes)) ||
166171 rewrite (h_and (broadcast (x, arg_lanes), lanes), broadcast (x, lanes)) ||
167- rewrite (h_and (broadcast (x, c0), lanes), h_and (x, lanes), factor % c0 == 0 ) ||
168- rewrite (h_and (ramp (x, y, arg_lanes) < broadcast (z, arg_lanes), lanes),
169- x + max (y * (arg_lanes - 1 ), 0 ) < z) ||
170- rewrite (h_and (ramp (x, y, arg_lanes) <= broadcast (z, arg_lanes), lanes),
171- x + max (y * (arg_lanes - 1 ), 0 ) <= z) ||
172- rewrite (h_and (broadcast (x, arg_lanes) < ramp (y, z, arg_lanes), lanes),
173- x < y + min (z * (arg_lanes - 1 ), 0 )) ||
174- rewrite (h_and (broadcast (x, arg_lanes) < ramp (y, z, arg_lanes), lanes),
175- x <= y + min (z * (arg_lanes - 1 ), 0 )) ||
172+ rewrite (h_and (broadcast (x, c0), lanes), broadcast (h_and (x, lanes / c0), c0), lanes % c0 == 0 ) ||
173+ rewrite (h_and (broadcast (x, c0), lanes), broadcast (h_and (x, 1 ), lanes), c0 >= lanes) ||
174+ (lanes == 1 && rewrite (h_and (ramp (x, y, arg_lanes) < broadcast (z, arg_lanes), lanes),
175+ x + max (y * (arg_lanes - 1 ), 0 ) < z)) ||
176+ (lanes == 1 && rewrite (h_and (ramp (x, y, arg_lanes) <= broadcast (z, arg_lanes), lanes),
177+ x + max (y * (arg_lanes - 1 ), 0 ) <= z)) ||
178+ (lanes == 1 && rewrite (h_and (broadcast (x, arg_lanes) < ramp (y, z, arg_lanes), lanes),
179+ x < y + min (z * (arg_lanes - 1 ), 0 ))) ||
180+ (lanes == 1 && rewrite (h_and (broadcast (x, arg_lanes) < ramp (y, z, arg_lanes), lanes),
181+ x <= y + min (z * (arg_lanes - 1 ), 0 ))) ||
176182 false ) {
177183 return mutate (rewrite.result , info);
178184 }
@@ -185,7 +191,8 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
185191 rewrite (h_or (x && broadcast (y, arg_lanes), lanes), h_or (x, lanes) && broadcast (y, lanes)) ||
186192 rewrite (h_or (broadcast (x, arg_lanes) && y, lanes), h_or (y, lanes) && broadcast (x, lanes)) ||
187193 rewrite (h_or (broadcast (x, arg_lanes), lanes), broadcast (x, lanes)) ||
188- rewrite (h_or (broadcast (x, c0), lanes), h_or (x, lanes), factor % c0 == 0 ) ||
194+ rewrite (h_or (broadcast (x, c0), lanes), broadcast (h_or (x, lanes / c0), c0), lanes % c0 == 0 ) ||
195+ rewrite (h_or (broadcast (x, c0), lanes), broadcast (h_or (x, 1 ), lanes), c0 >= lanes) ||
189196 // type of arg_lanes is somewhat indeterminate
190197 rewrite (h_or (ramp (x, y, arg_lanes) < broadcast (z, arg_lanes), lanes),
191198 x + min (y * (arg_lanes - 1 ), 0 ) < z) ||
0 commit comments