Skip to content

Commit 6f81ad7

Browse files
authored
[TorchToLinalg] Improve broadcast lowerings in strict symbolic modes (#2505)
With strict symbolic shapes, we can assume numpy-style dynamic broadcasts never occur. This improves the lowering in the presence of this assumption.
1 parent 42b6c0a commit 6f81ad7

File tree

5 files changed

+237
-39
lines changed

5 files changed

+237
-39
lines changed

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,31 +1095,35 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
10951095
// which in this case is `inShapeConverted` because this shape will yield
10961096
// us the dimension size of the output.
10971097
SmallVector<bool> useBroadcastToShape;
1098-
for (auto x : inShape) {
1098+
int64_t inputRank = self.getType().cast<RankedTensorType>().getRank();
1099+
for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e;
1100+
++i) {
10991101
int64_t dim;
1100-
if (!matchPattern(x, m_TorchConstantInt(&dim))) {
1101-
Operation *defOp = x.getDefiningOp();
1102-
if (isa<AtenSizeOp, AtenSizeIntOp>(defOp))
1103-
useBroadcastToShape.push_back(true);
1104-
else
1102+
if (matchPattern(inShape[i], m_TorchConstantInt(&dim))) {
1103+
if (dim < 0) {
11051104
useBroadcastToShape.push_back(false);
1105+
} else {
1106+
useBroadcastToShape.push_back(true);
1107+
}
11061108
} else {
1107-
useBroadcastToShape.push_back(false);
1109+
// Note: Dynamic -1 (inferred) broadcast shapes are unimplemented.
1110+
useBroadcastToShape.push_back(true);
11081111
}
11091112
}
11101113

11111114
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
11121115
rewriter, op.getLoc(), getTypeConverter(), inShape);
1116+
auto newResultType =
1117+
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
11131118
Value result;
1114-
if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self,
1115-
inShapeConverted, result,
1116-
useBroadcastToShape))) {
1119+
if (failed(torch_to_linalg::broadcastToGivenShape(
1120+
op, rewriter, self, inShapeConverted, newResultType, result,
1121+
useBroadcastToShape))) {
11171122
return rewriter.notifyMatchFailure(
11181123
op, "unable to perform broadcast operation");
11191124
}
11201125

1121-
Type newResultType = getTypeConverter()->convertType(op.getType());
1122-
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
1126+
rewriter.replaceOp(op, result);
11231127
return success();
11241128
}
11251129
};
@@ -1177,7 +1181,7 @@ class ConvertAtenCopyOp : public OpConversionPattern<AtenCopyOp> {
11771181
selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]);
11781182
Value broadcastedSrc;
11791183
if (failed(torch_to_linalg::broadcastToGivenShape(
1180-
op, rewriter, src, selfSizes, broadcastedSrc))) {
1184+
op, rewriter, src, selfSizes, selfType, broadcastedSrc))) {
11811185
return rewriter.notifyMatchFailure(
11821186
op, "unable to perform broadcast operation");
11831187
}

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,24 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
295295

296296
// Broadcast the batch dimensions of both the matrices.
297297
Value broadcastedLhs, broadcastedRhs;
298+
// TODO: Improve usage of static shape information.
299+
SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
300+
ShapedType::kDynamic);
301+
auto lhsBroadcastType =
302+
RankedTensorType::get(lhsTargetShape, lhsType.getElementType());
298303
if (failed(torch_to_linalg::broadcastToGivenShape(
299-
op, rewriter, lhs, lhsBroadcastToShape, broadcastedLhs))) {
304+
op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType,
305+
broadcastedLhs))) {
300306
return rewriter.notifyMatchFailure(
301307
op, "unable to perform broadcast operation");
302308
}
309+
SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
310+
ShapedType::kDynamic);
311+
auto rhsBroadcastType =
312+
RankedTensorType::get(rhsTargetShape, rhsType.getElementType());
303313
if (failed(torch_to_linalg::broadcastToGivenShape(
304-
op, rewriter, rhs, rhsBroadcastToShape, broadcastedRhs))) {
314+
op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType,
315+
broadcastedRhs))) {
305316
return rewriter.notifyMatchFailure(
306317
op, "unable to perform broadcast operation");
307318
}

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 111 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -327,22 +327,28 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
327327
// Broadcasts input tensor based on the broadcastToShape.
328328
LogicalResult torch_to_linalg::broadcastToGivenShape(
329329
Operation *op, PatternRewriter &rewriter, Value input,
330-
SmallVector<Value> broadcastToShape, Value &result,
331-
SmallVector<bool> useBroadcastToShape) {
330+
SmallVector<Value> broadcastToShape, RankedTensorType broadcastType,
331+
Value &result, SmallVector<bool> useBroadcastToShape) {
332332
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
333+
int64_t inputRank = inputType.getRank();
334+
int64_t outputRank = broadcastToShape.size();
335+
ArrayRef<int64_t> outputShape = broadcastType.getShape();
333336
SmallVector<int64_t> inputShape =
334337
makeShapeTorchCompatible(inputType.getShape());
335-
if (broadcastToShape.size() < inputShape.size()) {
338+
if (outputRank < inputRank) {
336339
return rewriter.notifyMatchFailure(
337340
op, "invalid shape: broadcastToShape size must not be smaller than the "
338341
"size of the input shape");
339342
}
340343

341344
Type elementType = inputType.getElementType();
342345
Location loc = op->getLoc();
343-
SmallVector<Value> outShape;
346+
SmallVector<OpFoldResult> outShape;
344347
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter);
345348

349+
// Vector indicating broadcasted status when assuming strict symbolic shapes.
350+
SmallVector<bool> broadcastedStatus;
351+
346352
// Create affine map and shapes for tensor initialization.
347353
SmallVector<AffineExpr> outExpr;
348354
Value zero =
@@ -351,10 +357,39 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
351357
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
352358
Value oneIndex =
353359
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
354-
size_t diff = broadcastToShape.size() - inputShape.size();
355-
for (size_t i = 0; i < broadcastToShape.size(); i++) {
360+
size_t diff = outputRank - inputRank;
361+
bool hasDynamicNumpyBroadcast = false;
362+
for (size_t i = 0, e = outputRank; i < e; i++) {
356363
Value shapeValue = broadcastToShape[i];
357364
size_t j = i - diff;
365+
bool isDynamic = i >= diff && inputShape[j] == kUnknownSize;
366+
367+
// Inherit static output shapes if present.
368+
if (outputShape[i] != ShapedType::kDynamic) {
369+
outShape.push_back(rewriter.getIndexAttr(outputShape[i]));
370+
if (i < diff) {
371+
if (outputShape[i] < 0) {
372+
return rewriter.notifyMatchFailure(
373+
op, "invalid shape: negative values not allowed in new broadcast "
374+
"dimensions");
375+
}
376+
continue;
377+
}
378+
if (isDynamic) {
379+
hasDynamicNumpyBroadcast = true;
380+
} else if (inputShape[j] != outputShape[i] && inputShape[j] != 1) {
381+
return rewriter.notifyMatchFailure(
382+
op, "invalid shape: static mismatch in input and output broadcast "
383+
"shapes");
384+
}
385+
386+
// If strict symbolic shapes are assumed and the input shape is dynamic,
387+
// we can assume that dim is not broadcasted.
388+
broadcastedStatus.push_back(inputShape[j] != outputShape[i] &&
389+
!isDynamic);
390+
continue;
391+
}
392+
358393
if (i < diff) {
359394
if (!elideDynamicBroadcastCheck) {
360395
Value isValid = rewriter.create<arith::CmpIOp>(
@@ -374,24 +409,80 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
374409
Value select = rewriter.create<arith::SelectOp>(
375410
loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
376411
outShape.push_back(select);
377-
} else {
378-
// Case of dynamic input dimension wherein the shape to broadcast will
379-
// yield us the dimension size of the output.
380-
Value dim = getDimOp(rewriter, loc, input, j);
381-
if (!useBroadcastToShape.empty()) {
382-
if (useBroadcastToShape[i])
383-
dim = castIntToIndex(rewriter, loc, broadcastToShape[j]);
412+
broadcastedStatus.push_back(true);
413+
continue;
414+
}
415+
416+
// Case of dynamic input dimension wherein the shape to broadcast will
417+
// yield us the dimension size of the output.
418+
Value dim;
419+
if (!useBroadcastToShape.empty() && useBroadcastToShape[j]) {
420+
dim = castIntToIndex(rewriter, loc, broadcastToShape[i]);
421+
if (isDynamic) {
422+
hasDynamicNumpyBroadcast = true;
384423
}
385-
outShape.push_back(dim);
424+
if (!elideDynamicBroadcastCheck) {
425+
Value isValid = rewriter.create<arith::CmpIOp>(
426+
loc, arith::CmpIPredicate::sge, shapeValue, zero);
427+
rewriter.create<cf::AssertOp>(
428+
loc, isValid,
429+
rewriter.getStringAttr(
430+
"unimplemented: dynamic negative broadcast sizes"));
431+
}
432+
} else {
433+
dim = getDimOp(rewriter, loc, input, j);
386434
}
435+
// We can safely assume this dimension is not broadcasted with strict
436+
// symbols.
437+
broadcastedStatus.push_back(false);
438+
outShape.push_back(dim);
387439
}
388440

389-
Value outTensor = rewriter.create<tensor::EmptyOp>(
390-
loc, getAsOpFoldResult(outShape), elementType);
441+
Value outTensor =
442+
rewriter.create<tensor::EmptyOp>(loc, outShape, elementType);
443+
444+
// If we know there are no ? -> ? broadcasted dims, or we are assuming
445+
// strict symbols, we can safely use standard linalg style broadcasting
446+
// semantics.
447+
if (!hasDynamicNumpyBroadcast || elideDynamicBroadcastCheck) {
448+
// If no dims are broadcasted and the rank doesn't change, we can just fold
449+
// the op away entirely.
450+
if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) &&
451+
inputRank == outputRank) {
452+
result = rewriter.create<tensor::CastOp>(loc, outTensor.getType(), input);
453+
return success();
454+
}
455+
456+
SmallVector<AffineExpr> inputExprs;
457+
for (int64_t i = 0, e = inputRank; i < e; ++i) {
458+
if (broadcastedStatus[i]) {
459+
inputExprs.push_back(rewriter.getAffineConstantExpr(0));
460+
continue;
461+
}
462+
inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
463+
}
464+
465+
SmallVector<AffineMap> indexingMaps = {
466+
AffineMap::get(outputRank, 0, inputExprs, rewriter.getContext()),
467+
rewriter.getMultiDimIdentityMap(outputRank)};
468+
SmallVector<utils::IteratorType> iteratorTypes(
469+
outputRank, utils::IteratorType::parallel);
470+
result = rewriter
471+
.create<linalg::GenericOp>(
472+
loc, outTensor.getType(), input, outTensor, indexingMaps,
473+
iteratorTypes,
474+
[&](OpBuilder &b, Location loc, ValueRange args) {
475+
b.create<linalg::YieldOp>(loc, args[0]);
476+
})
477+
.getResult(0);
478+
return success();
479+
}
391480

481+
// Fall back to numpy-style dynamic broadcasting in the form of a single
482+
// linalg op.
392483
SmallVector<AffineMap> indexingMaps = {
393-
rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
394-
SmallVector<utils::IteratorType> iteratorTypes(broadcastToShape.size(),
484+
rewriter.getMultiDimIdentityMap(outputRank)};
485+
SmallVector<utils::IteratorType> iteratorTypes(outputRank,
395486
utils::IteratorType::parallel);
396487
result = rewriter
397488
.create<linalg::GenericOp>(
@@ -402,7 +493,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
402493
// would be used to extract values from the input tensor
403494
// later on.
404495
SmallVector<Value> loopIndices;
405-
for (size_t i = 0; i < broadcastToShape.size(); ++i) {
496+
for (size_t i = 0, e = outputRank; i < e; ++i) {
406497
if (i < diff)
407498
continue;
408499
loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
@@ -411,7 +502,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
411502
// the i-th input dimension is not 1, else it contains a
412503
// zero index.
413504
SmallVector<Value> inputIndicesToExtract;
414-
for (size_t i = 0, n = inputShape.size(); i < n; i++) {
505+
for (size_t i = 0, n = inputRank; i < n; i++) {
415506
if (inputShape[i] == 1) {
416507
inputIndicesToExtract.push_back(zeroIndex);
417508
} else {

lib/Conversion/TorchToLinalg/Utils.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ Value createElementwiseLinalgGeneric(
7373
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
7474

7575
// Broadcasts input tensor based on the broadcastToShape.
76-
LogicalResult
77-
broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
78-
SmallVector<Value> broadcastToShape, Value &result,
79-
SmallVector<bool> useBroadcastToShape = {});
76+
LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
77+
Value input,
78+
SmallVector<Value> broadcastToShape,
79+
RankedTensorType broadcastType,
80+
Value &result,
81+
SmallVector<bool> useBroadcastToShape = {});
8082

8183
// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
8284
// <?x?xf32>
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @torch.aten.broadcast_to$simple_static(
4+
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<3x4x2xf32>
5+
// CHECK: %[[GENERIC:.*]] = linalg.generic
6+
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
7+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
8+
// CHECK-SAME: ins({{.*}} : tensor<4x2xf32>) outs({{.*}} : tensor<3x4x2xf32>) {
9+
// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
10+
// CHECK: linalg.yield %[[IN]] : f32
11+
// CHECK: } -> tensor<3x4x2xf32>
12+
func.func @torch.aten.broadcast_to$simple_static(%arg0: !torch.vtensor<[4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
13+
%int3 = torch.constant.int 3
14+
%int4 = torch.constant.int 4
15+
%int2 = torch.constant.int 2
16+
%list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
17+
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[4,2],f32>, !torch.list<int> -> !torch.vtensor<[3,4,2],f32>
18+
return %0 : !torch.vtensor<[3,4,2],f32>
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: func.func @torch.aten.broadcast_to$static_numpy_broadcast(
24+
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<1x4x2xf32>
25+
// CHECK: %[[GENERIC:.*]] = linalg.generic
26+
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
27+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
28+
// CHECK-SAME: ins({{.*}} : tensor<1x1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
29+
// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
30+
// CHECK: linalg.yield %[[IN]] : f32
31+
// CHECK: } -> tensor<1x4x2xf32>
32+
func.func @torch.aten.broadcast_to$static_numpy_broadcast(%arg0: !torch.vtensor<[1,1,2],f32>) -> !torch.vtensor<[1,4,2],f32> {
33+
%int1 = torch.constant.int 1
34+
%int4 = torch.constant.int 4
35+
%int2 = torch.constant.int 2
36+
%list = torch.prim.ListConstruct %int1, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
37+
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[1,1,2],f32>, !torch.list<int> -> !torch.vtensor<[1,4,2],f32>
38+
return %0 : !torch.vtensor<[1,4,2],f32>
39+
}
40+
41+
// -----
42+
43+
// CHECK-LABEL: func.func @torch.aten.broadcast_to$empty_input(
44+
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor<?xf32>
45+
// CHECK: %[[GENERIC:.*]] = linalg.generic
46+
// CHECK-SAME: indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>]
47+
// CHECK-SAME: iterator_types = ["parallel"]}
48+
// CHECK-SAME: ins({{.*}} : tensor<f32>) outs({{.*}} : tensor<?xf32>) {
49+
// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
50+
// CHECK: linalg.yield %[[IN]] : f32
51+
// CHECK: } -> tensor<?xf32>
52+
func.func @torch.aten.broadcast_to$empty_input(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> {
53+
%list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
54+
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
55+
return %0 : !torch.vtensor<[?],f32>
56+
}
57+
58+
// -----
59+
60+
// CHECK-LABEL: func.func @torch.aten.broadcast_to$strict_dynamic_broadcast(
61+
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor<?x?xf32>
62+
// CHECK: %[[GENERIC:.*]] = linalg.generic
63+
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>]
64+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
65+
// CHECK-SAME: ins({{.*}} : tensor<?xf32>) outs({{.*}} : tensor<?x?xf32>) {
66+
// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
67+
// CHECK: linalg.yield %[[IN]] : f32
68+
// CHECK: } -> tensor<?x?xf32>
69+
func.func @torch.aten.broadcast_to$strict_dynamic_broadcast(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> attributes {torch.assume_strict_symbolic_shapes} {
70+
%list = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list<int>
71+
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
72+
return %0 : !torch.vtensor<[?,?],f32>
73+
}
74+
75+
// -----
76+
77+
/// Nothing we can do; verify we hit the fall back path.
78+
// CHECK-LABEL: func.func @torch.aten.broadcast_to$pure_dynamic_broadcast(
79+
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor<?x?xf32>
80+
// CHECK: %[[GENERIC:.*]] = linalg.generic
81+
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>]
82+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
83+
// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
84+
// CHECK: ^bb0(%[[OUT:.+]]: f32):
85+
// CHECK: tensor.extract
86+
func.func @torch.aten.broadcast_to$pure_dynamic_broadcast(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> {
87+
%list = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list<int>
88+
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
89+
return %0 : !torch.vtensor<[?,?],f32>
90+
}

0 commit comments

Comments
 (0)