Commit 860be09
Elide dynamic broadcast checks when in strict symbolic shapes mode. (#2496)
When importing dynamically shaped programs from Dynamo, via torch.compile or torch.export, we can assume that strict symbolic shape checks have been done prior to generating Torch IR. Among other shape checks, this eliminates the case where an unknown dimension can dynamically be '1' in a way that signals a broadcast. Adds an `isAssumingStrictSymbolicShapes` utility which consults a `torch.assume_strict_symbolic_shapes` attribute on an enclosing scope and returns true if present. In the linalg pipeline, many runtime checks are elided when this returns true.
1 parent 4e1dd3b · commit 860be09
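The guard applied by the lowerings below is the same everywhere; the following is a minimal sketch assembled from the hunks in this commit, not code from the change itself (actualDimSize and expectedDimSize are placeholder names for whichever two sizes a given pattern compares):

// Sketch only: emit the runtime size check unless an enclosing op carries
// the torch.assume_strict_symbolic_shapes unit attribute.
if (!isAssumingStrictSymbolicShapes(rewriter)) {
  Value sizesEqual = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, actualDimSize, expectedDimSize);
  rewriter.create<cf::AssertOp>(
      loc, sizesEqual,
      rewriter.getStringAttr("mismatched size for broadcast"));
}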

File tree

10 files changed: +126 -64 lines changed


include/torch-mlir/Dialect/Torch/Utils/Utils.h

Lines changed: 18 additions & 0 deletions
@@ -86,6 +86,24 @@ FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value input, Value dim);
 
+// In Dynamo import paths, we can assume that dynamic dimensions are strictly
+// quantities and are not ambiguous with '1' symbols that can be interpreted
+// to signal an expansion in various broadcasting scenarios. In the
+// torch.compile eager path, this precondition is assured by guards on 0/1
+// dimension values, and on the torch.export graph-capture path, the shape
+// solver guarantees this.
+//
+// We let lowerings assume this on a per-scope basis if the
+// torch.assume_strict_symbolic_shapes unit attribute is present on any parent
+// of the block.
+bool isAssumingStrictSymbolicShapes(Block *scope);
+
+// Helper that uses the block from an OpBuilder for determining whether we
+// are assuming strict symbolic shapes.
+inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
+  return isAssumingStrictSymbolicShapes(builder.getBlock());
+}
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 12 additions & 9 deletions
@@ -656,20 +656,23 @@ class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
       reassociation[0].push_back(headOnesCount++);
     }
 
-    // TODO: Add support for size-1 dynamic dimensions.
     Value one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
     int64_t j = -1;
+    bool elideDynamicBroadcastDimCheck =
+        isAssumingStrictSymbolicShapes(rewriter);
     for (auto i : llvm::seq<int64_t>(headOnesCount, inputRank)) {
       if (inputType.isDynamicDim(i)) {
-        // Make sure that size-1 dynamic dimension does not exist.
-        Value dimSize = getDimOp(rewriter, loc, input, i);
-        Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::ne, dimSize, one);
-        rewriter.create<cf::AssertOp>(
-            loc, dimSizeNotOne,
-            rewriter.getStringAttr(
-                "unimplemented: size 1 dynamic dimension is not supported"));
+        if (!elideDynamicBroadcastDimCheck) {
+          // Make sure that size-1 dynamic dimension does not exist.
+          Value dimSize = getDimOp(rewriter, loc, input, i);
+          Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::ne, dimSize, one);
+          rewriter.create<cf::AssertOp>(
+              loc, dimSizeNotOne,
+              rewriter.getStringAttr(
+                  "unimplemented: size 1 dynamic dimension is not supported"));
+        }
         ++j;
       } else if (inputType.getDimSize(i) != 1) {
         ++j;

lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 10 additions & 8 deletions
@@ -644,14 +644,16 @@ class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern<AtenIndexT
           return rewriter.notifyMatchFailure(
               op,
               "unimplemented: index tensors with overlapping dynamic dims");
-        if (staticDimSize > 1) {
-          Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
-                                               rewriter.getIndexType());
-          auto equalToRunning = rewriter.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::eq, cstStaticDimSize,
-              dynamicDims[0]);
-          rewriter.create<cf::AssertOp>(loc, equalToRunning,
-                                        "mismatched size for broadcast");
+        if (!isAssumingStrictSymbolicShapes(rewriter)) {
+          if (staticDimSize > 1) {
+            Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
+                                                 rewriter.getIndexType());
+            auto equalToRunning = rewriter.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::eq, cstStaticDimSize,
+                dynamicDims[0]);
+            rewriter.create<cf::AssertOp>(loc, equalToRunning,
+                                          "mismatched size for broadcast");
+          }
         }
         broadcastedIndexShape.push_back(dynamicDims[0]);
       } else {

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 11 additions & 8 deletions
@@ -58,15 +58,18 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
     }
 
     Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-    Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
-    Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
     Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-    Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
-    rewriter.create<cf::AssertOp>(
-        loc, contractingDimEqual,
-        rewriter.getStringAttr(
-            "mismatching contracting dimension for torch.aten.mm"));
+
+    if (!isAssumingStrictSymbolicShapes(rewriter)) {
+      Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
+      Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
+      Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
+      rewriter.create<cf::AssertOp>(
+          loc, contractingDimEqual,
+          rewriter.getStringAttr(
+              "mismatching contracting dimension for torch.aten.mm"));
+    }
 
     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = newResultType.cast<TensorType>().getElementType();

lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp

Lines changed: 3 additions & 1 deletion
@@ -42,7 +42,9 @@ class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
     Value inputRank = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getI64IntegerAttr(type.getRank()));
     Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
-    assertIsValidDim(rewriter, loc, dimPositive, inputRank);
+    if (!isAssumingStrictSymbolicShapes(rewriter)) {
+      assertIsValidDim(rewriter, loc, dimPositive, inputRank);
+    }
     Value size = rewriter.create<tensor::DimOp>(
         loc, adaptor.getSelf(), castIntToIndex(rewriter, loc, dimPositive));
     rewriter.replaceOp(op, castIndexToInt64(rewriter, loc, size));

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 6 additions & 4 deletions
@@ -1481,10 +1481,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
             rewriter.getStringAttr(
                 "expect the size of dim 0 equal to the number of features"));
       };
-      contractingDim0EqualsNumFeatures(weight);
-      contractingDim0EqualsNumFeatures(bias);
-      contractingDim0EqualsNumFeatures(runningMean);
-      contractingDim0EqualsNumFeatures(runningVar);
+      if (!isAssumingStrictSymbolicShapes(rewriter)) {
+        contractingDim0EqualsNumFeatures(weight);
+        contractingDim0EqualsNumFeatures(bias);
+        contractingDim0EqualsNumFeatures(runningMean);
+        contractingDim0EqualsNumFeatures(runningVar);
+      }
 
       auto indexingMap = AffineMap::get(
           /*dimCount=*/inputRank,

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 19 additions & 12 deletions
@@ -231,7 +231,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
   //   if this is the first tensor operand that didn't continue above:
   //     take its dimension size as the size of the non-broadcasted
   //     traversal along this dimension (this may include a dynamic size-1,
-  //     **non-broadcasted** traversal!)
+  //     **non-broadcasted** traversal unless if
+  //     isAssumingStrictSymbolicShapes!)
   //   emit error check "if the size does not match the non-broadcasted
   //   traversal size along this dimension, error"
   // ```
@@ -251,6 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
   auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
   SmallVector<Value> resultShape(resultRank, c1);
   SmallVector<AffineMap> indexingMaps;
+  bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);
   for (Value tensorOperand : tensorOperands) {
     SmallVector<AffineExpr> exprs;
     auto type = tensorOperand.getType().cast<RankedTensorType>();
@@ -294,11 +296,13 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
       // This is the check which protects against the undefined behavior of
      // the generated linalg op in the case of iterating two operands with
       // dimensions sizes that are expected to match.
-      auto equalToRunning =
-          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                  resultShape[resultDim], currentDimSize);
-      b.create<cf::AssertOp>(loc, equalToRunning,
-                             "mismatched size for broadcast");
+      if (!elideDynamicBroadcastCheck) {
+        auto equalToRunning =
+            b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                    resultShape[resultDim], currentDimSize);
+        b.create<cf::AssertOp>(loc, equalToRunning,
+                               "mismatched size for broadcast");
+      }
     }
     indexingMaps.push_back(AffineMap::get(
         /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
@@ -337,6 +341,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
   Type elementType = inputType.getElementType();
   Location loc = op->getLoc();
   SmallVector<Value> outShape;
+  bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter);
 
   // Create affine map and shapes for tensor initialization.
   SmallVector<AffineExpr> outExpr;
@@ -351,12 +356,14 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     Value shapeValue = broadcastToShape[i];
     size_t j = i - diff;
     if (i < diff) {
-      Value isValid = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, shapeValue, zero);
-      rewriter.create<cf::AssertOp>(
-          loc, isValid,
-          rewriter.getStringAttr(
-              "negative values not allowed in new dimensions"));
+      if (!elideDynamicBroadcastCheck) {
+        Value isValid = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, shapeValue, zero);
+        rewriter.create<cf::AssertOp>(
+            loc, isValid,
+            rewriter.getStringAttr(
+                "negative values not allowed in new dimensions"));
+      }
       outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
       continue;
     }

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 24 additions & 19 deletions
@@ -3484,11 +3484,13 @@ class DecomposeAtenAdaptiveAvgPool1dOp
                   : rewriter.create<Torch::ConstantIntOp>(
                         loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
     } else {
-      Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
-      rewriter.create<RuntimeAssertOp>(
-          loc, cond,
-          "unimplemented: only support cases where input and output size are "
-          "equal for non-unit output size");
+      if (!isAssumingStrictSymbolicShapes(rewriter)) {
+        Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
+        rewriter.create<RuntimeAssertOp>(
+            loc, cond,
+            "unimplemented: only support cases where input and output size are "
+            "equal for non-unit output size");
+      }
       kernelSize.push_back(constantOne);
     }
 
@@ -3586,13 +3588,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
                 loc, rewriter.getI64IntegerAttr(
                          inputShape[rank - 2 + i])));
       } else {
-        Value cond = rewriter.create<AtenEqIntOp>(loc, inputHW[i],
-                                                  outputShapeSizesTorchInt[i]);
-        rewriter.create<RuntimeAssertOp>(
-            loc, cond,
-            "unimplemented: only support cases where input and output size are "
-            "equal for non-unit output size");
-
+        if (!isAssumingStrictSymbolicShapes(rewriter)) {
+          Value cond = rewriter.create<AtenEqIntOp>(
+              loc, inputHW[i], outputShapeSizesTorchInt[i]);
+          rewriter.create<RuntimeAssertOp>(loc, cond,
+                                           "unimplemented: only support cases "
+                                           "where input and output size are "
+                                           "equal for non-unit output size");
+        }
         Value outMinusOne = rewriter.create<AtenSubIntOp>(
             loc, outputShapeSizesTorchInt[i], constantOne);
         kernelSize.push_back(
@@ -3822,13 +3825,15 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
       loc, rewriter.getF64FloatAttr(correction));
   // The `correction` value should be less than or equal to `productDimSize +
   // 1`.
-  Value productDimSizePlusOne = rewriter.create<AtenAddOp>(
-      loc, productDimSize.getType(), productDimSize, constantOne);
-  Value cond =
-      rewriter.create<AtenGeFloatOp>(loc, productDimSizePlusOne, cstCorrection);
-  rewriter.create<RuntimeAssertOp>(
-      loc, cond,
-      "correction value should be less than or equal to productDimSize + 1");
+  if (!isAssumingStrictSymbolicShapes(rewriter)) {
+    Value productDimSizePlusOne = rewriter.create<AtenAddOp>(
+        loc, productDimSize.getType(), productDimSize, constantOne);
+    Value cond = rewriter.create<AtenGeFloatOp>(loc, productDimSizePlusOne,
+                                                cstCorrection);
+    rewriter.create<RuntimeAssertOp>(
+        loc, cond,
+        "correction value should be less than or equal to productDimSize + 1");
+  }
   Value productDimSizeSubCorrection =
       rewriter.create<AtenSubFloatOp>(loc, productDimSize, cstCorrection);
   Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 9 additions & 0 deletions
@@ -324,3 +324,12 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
       op->getLoc(), unsqueezedType, input, dim);
   return unsqueezed;
 }
+
+bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
+  for (Operation *parentOp = block->getParentOp(); parentOp;
+       parentOp = parentOp->getParentOp()) {
+    if (parentOp->hasAttr("torch.assume_strict_symbolic_shapes"))
+      return true;
+  }
+  return false;
+}
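The walk above only reads the attribute; nothing in this commit sets it, so the Dynamo import paths named in the commit message are expected to annotate an enclosing op themselves. A hedged sketch of that producer side, assuming the standard MLIR attribute API (markFunctionStrict is a hypothetical helper, not part of this change):

#include "mlir/Dialect/Func/IR/FuncOps.h"

// Hypothetical helper: tag a function so that lowerings nested under it
// observe isAssumingStrictSymbolicShapes() == true.
static void markFunctionStrict(mlir::func::FuncOp funcOp) {
  funcOp->setAttr("torch.assume_strict_symbolic_shapes",
                  mlir::UnitAttr::get(funcOp.getContext()));
}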

test/Conversion/TorchToLinalg/basic.mlir

Lines changed: 14 additions & 3 deletions
@@ -8,11 +8,11 @@
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[RHS_DIM_0:.*]] = tensor.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
 // CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index
 // CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm"
 // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty(%[[LHS_DIM_0]], %[[RHS_DIM_1]]) : tensor<?x?xf32>
@@ -29,6 +29,17 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
+// CHECK-NOT: assert
+func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
+  attributes {torch.assume_strict_symbolic_shapes}
+{
+  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
+  return %0 : !torch.vtensor<[?,2],f32>
+}
+
+// -----
+
 // If the operands are missing dtype, we cannot lower it.
 func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
   // expected-error@+1 {{failed to legalize}}
@@ -264,4 +275,4 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten
 func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
   %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
   return %0 : !torch.vtensor<[?,?],f16>
-}
+}
