
Commit dad1f01

Add verification for torch permute op (#2551)
- Adds support for an optional verifier to the generated torch op tablegen (GeneratedTorchOps.td).
- Uses the above to add a verifier for the torch permute op.

Motivation: I hit an unclear error from linalg while developing a decomposition pass for pixel_shuffle. The error would have been clearer if the problem had been detected earlier, in the invalid aten.permute op.

Testing: new tests added. To run the added tests, from the base directory run

```
./build/bin/llvm-lit test/Dialect/Torch/invalid.mlir
```
1 parent e81282a commit dad1f01
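For context, here is a minimal eager-PyTorch sketch of the failure mode the new verifier targets (an illustrative repro added for this write-up, not part of the commit): a permutation whose length does not match the tensor rank. Eager PyTorch rejects this up front, and the new `AtenPermuteOp::verify()` now reports the analogous mismatch at the IR level instead of letting it surface later as an unclear linalg error.

```python
import torch

# Illustrative repro (assumed example, not from the commit): permuting a
# rank-3 tensor with a 2-element permutation.
x = torch.randn(2, 3, 5)
try:
    x.permute(1, 2)  # only 2 dims given for a rank-3 tensor
except RuntimeError as e:
    print(e)  # PyTorch's own error, e.g. "number of dims don't match in permute"
```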

5 files changed: +202, -15 lines

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 12 additions & 11 deletions
@@ -6422,51 +6422,52 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
   }];
 }
 
-def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
+def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
   AllowsTypeRefinement,
+  HasValueSemantics,
   ReadOnly
 ]> {
-  let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
+  let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$dims
+    Torch_IntType:$upscale_factor
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 2, 1);
     }
-    void AtenPermuteOp::print(OpAsmPrinter &printer) {
+    void AtenPixelShuffleOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
 }
 
-def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
+def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
   AllowsTypeRefinement,
-  HasValueSemantics,
   ReadOnly
 ]> {
-  let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`";
+  let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    Torch_IntType:$upscale_factor
+    AnyTorchListOfTorchIntType:$dims
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 2, 1);
     }
-    void AtenPixelShuffleOp::print(OpAsmPrinter &printer) {
+    void AtenPermuteOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasVerifier = 1;
 }
 
 def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 90 additions & 0 deletions
@@ -2859,6 +2859,96 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
   return success();
 }
 
+LogicalResult AtenPermuteOp::verify() {
+
+  // Verification of the permute op for input & output dimensions with
+  // statically known sizes.
+
+  SmallVector<Value> permutation;
+  auto permutationObtained = getListConstructElements(getDims(), permutation);
+  if (!permutationObtained) {
+    return success();
+  }
+
+  auto outType = getResult().getType().cast<BaseTensorType>();
+  auto inType = getSelf().getType().cast<BaseTensorType>();
+
+  if (!outType.hasSizes() || !inType.hasSizes()) {
+    return success();
+  }
+
+  auto outShape = outType.getSizes();
+  auto inShape = inType.getSizes();
+
+  auto outRank = outShape.size();
+
+  if (outRank != inShape.size()) {
+    return emitOpError(
+               "expected input and output tensors to have same rank, but ")
+           << inShape.size() << " != " << outRank << '.';
+  }
+
+  if (outRank != permutation.size()) {
+    return emitOpError() << "expected permutation to have size equal result "
+                            "tensor rank. The permutation has "
+                         << permutation.size()
+                         << " elements, the output has rank " << outRank << '.';
+  }
+
+  // Initialization of the reverse permutation. -1 denotes an unknown
+  // permutation index.
+  SmallVector<int64_t> reversePermutation(outRank, -1);
+
+  // In this loop:
+  // (1) check that the permutation indices are in bounds, and not duplicated.
+  // (2) populate reversePermutation (to check for duplicates).
+  // (3) check that the input and output shapes agree with the permutation. For
+  //     example, if the permutation is (1,2,0) and the input shape is (2,3,5),
+  //     then the output shape must be (3,5,2).
+
+  for (uint64_t to = 0; to < outRank; ++to) {
+    int64_t from;
+
+    auto fromIsSet = matchPattern(permutation[to], m_TorchConstantInt(&from));
+
+    if (!fromIsSet) {
+      continue;
+    }
+
+    // If 'from' is the unknown index, continue.
+    if (from == -1) {
+      continue;
+    }
+
+    if (!isValidDim(from, outRank)) {
+      return emitError("observed invalid index in permutation (")
+             << from << ") for input tensor of rank " << outRank << '.';
+    }
+
+    if (reversePermutation[from] != -1) {
+      return emitOpError("has a duplicate dimension (")
+             << from << ") in its permutation " << getDims() << '.';
+    }
+    reversePermutation[from] = to;
+
+    auto dimSizesDefined =
+        inShape[from] != kUnknownSize && outShape[to] != kUnknownSize;
+    auto dimSizesDifferent = inShape[from] != outShape[to];
+
+    if (dimSizesDefined && dimSizesDifferent) {
+      return emitOpError("has a permutation which is not compatible with the "
+                         "input and output shapes. ")
+             << "The input shape in dimension " << from << " is "
+             << inShape[from] << ", and the output shape in dimension " << to
+             << " is " << outShape[to]
+             << " : they should be the same with this permutation. ";
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DtypeCalculateYieldDtypesOp
 //===----------------------------------------------------------------------===//

projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 8 additions & 4 deletions
@@ -114,7 +114,7 @@ def _get_main_module_name() -> str:
 def raw_emit_op(operator: JitOperator,
                 emitter_td: TextEmitter,
                 *, traits: List[str],
-                has_folder: bool, has_canonicalizer: bool):
+                has_folder: bool, has_canonicalizer: bool, has_verifier: bool):
     """Emit the ODS for a JitOperator to a textual file.
 
     This is the lowest level of emission and is responsible for low-level
@@ -199,6 +199,8 @@ def generic_result_name(i):
         p_td("let hasFolder = 1;")
     if has_canonicalizer:
         p_td("let hasCanonicalizer = 1;")
+    if has_verifier:
+        p_td("let hasVerifier = 1;")
     p_td("}")
     p_td("\n")
 
@@ -208,7 +210,8 @@ def emit_op(operator: JitOperator,
             *,
             traits: Optional[List[str]] = None,
             has_folder: bool = False,
-            has_canonicalizer: bool = False):
+            has_canonicalizer: bool = False,
+            has_verifier: bool = False):
     """Main entry point for op emission.
 
     Besides emitting the op, it deduces / adds traits based on the operator
@@ -228,7 +231,8 @@ def emit_op(operator: JitOperator,
                 emitter_td,
                 traits=traits,
                 has_folder=has_folder,
-                has_canonicalizer=has_canonicalizer)
+                has_canonicalizer=has_canonicalizer,
+                has_verifier=has_verifier)
 
 
 def emit_ops(emitter_td: TextEmitter, registry: Registry):
@@ -481,8 +485,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)")
     emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
     emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
-    emit("aten::permute : (Tensor, int[]) -> (Tensor)")
     emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
+    emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
     emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
     emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")

test/Dialect/Torch/invalid.mlir

Lines changed: 81 additions & 0 deletions
@@ -281,3 +281,84 @@ func.func @torch.tensor_static_info_cast$dtype_mismatch(%arg0: !torch.vtensor<*,
   %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<*,f64>
   return %0 : !torch.vtensor<*,f64>
 }
+
+// -----
+
+func.func @torch.permute$test_changing_rank (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
+
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+
+  %perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{expected input and output tensors to have same rank, but 3 != 4}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
+
+  return %3 : !torch.vtensor<[1,2,3,4],f32>
+}
+
+// -----
+
+func.func @torch.permute$test_permutation_too_short (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
+
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+
+  %perm = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{The permutation has 2 elements, the output has rank 3}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
+
+  return %3 : !torch.vtensor<[1,2,3],f32>
+}
+
+// -----
+
+func.func @torch.permute$duplicate_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[2,3,1],f32> {
+
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %perm = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{'torch.aten.permute' op has a duplicate dimension (1) in its permutation}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1],f32>
+
+  return %3 : !torch.vtensor<[2,3,1],f32>
+}
+
+// -----
+
+func.func @torch.permute$incorrect_output_shape (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[3,1,2],f32> {
+
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %none = torch.constant.none
+
+  %perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{'torch.aten.permute' op has a permutation which is not compatible with the input and output shapes. The input shape in dimension 1 is 2, and the output shape in dimension 0 is 3 : they should be the same with this permutation.}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[3,1,2],f32>
+
+  return %3 : !torch.vtensor<[3,1,2],f32>
+}
+
+// -----
+
+func.func @torch.permute$invalid_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
+
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int7 = torch.constant.int 7
+  %perm = torch.prim.ListConstruct %int0, %int1, %int7 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{observed invalid index in permutation (7) for input tensor of rank 3.}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
+
+  return %3 : !torch.vtensor<[1,2,3],f32>
+}

test/Dialect/Torch/ops.mlir

Lines changed: 11 additions & 0 deletions
@@ -170,3 +170,14 @@ func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,5
   %arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>>
   return %arg2 : !torch.list<vtensor<[1,?,56,96],f16>>
 }
+
+// Check that verification passes with '-1' as a permutation index.
+func.func @torch.permute$negative_index_valid (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
+  %intm1 = torch.constant.int -1
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %perm = torch.prim.ListConstruct %int0, %int1, %intm1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
+  return %3 : !torch.vtensor<[1,2,3],f32>
+}
