Skip to content

Commit 2f810de

Browse files
mcourteauxabadamsGemini 3.1 Pro
committed
fix(simplifier): VectorReduce(Broadcast(Vector)) used wrong semantics of Broadcast(Vector).
fix(simplifier): Cherry-picked fix by @abadams in #8629. Co-authored-by: Andrew Adams <andrew.b.adams@gmail.com> Co-authored-by: Gemini 3.1 Pro <gemini@aistudio.google.com>
1 parent db9fcd8 commit 2f810de

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

src/Simplify_Exprs.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,10 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
124124
auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type);
125125
if (rewrite(h_add(x * broadcast(y, arg_lanes), lanes), h_add(x, lanes) * broadcast(y, lanes)) ||
126126
rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes)) ||
127-
rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes))) {
127+
rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes)) ||
128+
rewrite(h_add(broadcast(x, c0), lanes), broadcast(h_add(x, lanes / c0), c0), lanes % c0 == 0) ||
129+
rewrite(h_add(broadcast(x, c0), lanes), broadcast(h_add(x, 1) * (c0 / lanes), lanes), c0 % lanes == 0) ||
130+
false) {
128131
return mutate(rewrite.result, info);
129132
}
130133
break;
@@ -136,8 +139,9 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
136139
rewrite(h_min(max(x, broadcast(y, arg_lanes)), lanes), max(h_min(x, lanes), broadcast(y, lanes))) ||
137140
rewrite(h_min(max(broadcast(x, arg_lanes), y), lanes), max(h_min(y, lanes), broadcast(x, lanes))) ||
138141
rewrite(h_min(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
139-
rewrite(h_min(broadcast(x, c0), lanes), h_min(x, lanes), factor % c0 == 0) ||
140-
rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0)) ||
142+
rewrite(h_min(broadcast(x, c0), 1), h_min(x, 1)) ||
143+
rewrite(h_min(broadcast(x, c0), lanes), broadcast(h_min(x, lanes / c0), c0), lanes % c0 == 0) ||
144+
(lanes == 1 && rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0))) ||
141145
false) {
142146
return mutate(rewrite.result, info);
143147
}
@@ -150,8 +154,9 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
150154
rewrite(h_max(max(x, broadcast(y, arg_lanes)), lanes), max(h_max(x, lanes), broadcast(y, lanes))) ||
151155
rewrite(h_max(max(broadcast(x, arg_lanes), y), lanes), max(h_max(y, lanes), broadcast(x, lanes))) ||
152156
rewrite(h_max(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
153-
rewrite(h_max(broadcast(x, c0), lanes), h_max(x, lanes), factor % c0 == 0) ||
154-
rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0)) ||
157+
rewrite(h_max(broadcast(x, c0), 1), h_max(x, 1)) ||
158+
rewrite(h_max(broadcast(x, c0), lanes), broadcast(h_max(x, lanes / c0), c0), lanes % c0 == 0) ||
159+
(lanes == 1 && rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0))) ||
155160
false) {
156161
return mutate(rewrite.result, info);
157162
}
@@ -164,15 +169,16 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
164169
rewrite(h_and(x && broadcast(y, arg_lanes), lanes), h_and(x, lanes) && broadcast(y, lanes)) ||
165170
rewrite(h_and(broadcast(x, arg_lanes) && y, lanes), h_and(y, lanes) && broadcast(x, lanes)) ||
166171
rewrite(h_and(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
167-
rewrite(h_and(broadcast(x, c0), lanes), h_and(x, lanes), factor % c0 == 0) ||
168-
rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
169-
x + max(y * (arg_lanes - 1), 0) < z) ||
170-
rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
171-
x + max(y * (arg_lanes - 1), 0) <= z) ||
172-
rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
173-
x < y + min(z * (arg_lanes - 1), 0)) ||
174-
rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
175-
x <= y + min(z * (arg_lanes - 1), 0)) ||
172+
rewrite(h_and(broadcast(x, c0), lanes), broadcast(h_and(x, lanes / c0), c0), lanes % c0 == 0) ||
173+
rewrite(h_and(broadcast(x, c0), lanes), broadcast(h_and(x, 1), lanes), c0 >= lanes) ||
174+
(lanes == 1 && rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
175+
x + max(y * (arg_lanes - 1), 0) < z)) ||
176+
(lanes == 1 && rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
177+
x + max(y * (arg_lanes - 1), 0) <= z)) ||
178+
(lanes == 1 && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
179+
x < y + min(z * (arg_lanes - 1), 0))) ||
180+
(lanes == 1 && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
181+
x <= y + min(z * (arg_lanes - 1), 0))) ||
176182
false) {
177183
return mutate(rewrite.result, info);
178184
}
@@ -185,7 +191,8 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
185191
rewrite(h_or(x && broadcast(y, arg_lanes), lanes), h_or(x, lanes) && broadcast(y, lanes)) ||
186192
rewrite(h_or(broadcast(x, arg_lanes) && y, lanes), h_or(y, lanes) && broadcast(x, lanes)) ||
187193
rewrite(h_or(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
188-
rewrite(h_or(broadcast(x, c0), lanes), h_or(x, lanes), factor % c0 == 0) ||
194+
rewrite(h_or(broadcast(x, c0), lanes), broadcast(h_or(x, lanes / c0), c0), lanes % c0 == 0) ||
195+
rewrite(h_or(broadcast(x, c0), lanes), broadcast(h_or(x, 1), lanes), c0 >= lanes) ||
189196
// type of arg_lanes is somewhat indeterminate
190197
rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
191198
x + min(y * (arg_lanes - 1), 0) < z) ||

test/correctness/simplify.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -805,11 +805,11 @@ void check_vectors() {
805805
check(VectorReduce::make(VectorReduce::And, Broadcast::make(bool_vector, 4), 1),
806806
VectorReduce::make(VectorReduce::And, bool_vector, 1));
807807
check(VectorReduce::make(VectorReduce::Or, Broadcast::make(bool_vector, 4), 2),
808-
VectorReduce::make(VectorReduce::Or, bool_vector, 2));
808+
Broadcast::make(VectorReduce::make(VectorReduce::Or, bool_vector, 1), 2));
809809
check(VectorReduce::make(VectorReduce::Min, Broadcast::make(int_vector, 4), 4),
810-
int_vector);
810+
Broadcast::make(VectorReduce::make(VectorReduce::Min, int_vector, 1), 4));
811811
check(VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8),
812-
VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8));
812+
Broadcast::make(VectorReduce::make(VectorReduce::Max, int_vector, 2), 4));
813813

814814
{
815815
// h_add(broadcast(x, 8), 4) should simplify to broadcast(x * 2, 4)

0 commit comments

Comments
 (0)