Skip to content

Commit 714b7fc

Browse files
committed
Fix casting for arith.cmpi operands to be of same type.
1 parent 8a982db commit 714b7fc

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -421,12 +421,15 @@ static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
421421
Value input, int64_t dim) {
422422
// performs the operation : index = index % maxIndex to wrap index around
423423
// maxIndex
424-
Value maxIndexValue = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim));
425-
Value isBeyondMaxIndices = b.create<arith::CmpIOp>(
424+
Value maxIndexValue = getDimOp(b, loc, input, dim);
425+
maxIndexValue =
426+
b.createOrFold<arith::IndexCastOp>(loc, index.getType(), maxIndexValue);
427+
Value isBeyondMaxIndices = b.createOrFold<arith::CmpIOp>(
426428
loc, arith::CmpIPredicate::sge, index, maxIndexValue);
427-
Value wrappedIndices = b.create<arith::RemSIOp>(loc, index, maxIndexValue);
428-
return b.create<arith::SelectOp>(loc, isBeyondMaxIndices, wrappedIndices,
429-
index);
429+
Value wrappedIndices =
430+
b.createOrFold<arith::RemSIOp>(loc, index, maxIndexValue);
431+
return b.createOrFold<arith::SelectOp>(loc, isBeyondMaxIndices,
432+
wrappedIndices, index);
430433
}
431434

432435
namespace {

0 commit comments

Comments
 (0)