Skip to content

Commit 1cad353

Browse files
committed
[mlir][tosa] Fix invalid data type combinations check
Previously this check assumed that if an operator exists in profile compliance (TosaProfileComplianceData.h.inc), an entry exists in both the profiles and extensions section. However, this is not necessarily the case. This commit changes the check such that it doesn't assume the above. In doing so, it allows more operators to be checked for invalid data type combinations, which were otherwise skipped previously. Change-Id: I2a7bc9be167463d29bf5d9ab1de946c26594845e
1 parent 3b8adcf commit 1cad353

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
464464
CheckCondition condition = CheckCondition::invalid;
465465
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
466466
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
467+
if (failed(maybeProfDef) && failed(maybeExtDef))
468+
return success();
467469

468-
if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
469-
!maybeProfDef.value().size() && !maybeExtDef.value().size()) {
470+
const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
471+
(succeeded(maybeExtDef) && !maybeExtDef->empty());
472+
if (!hasEntry) {
470473
std::string message;
471474
llvm::raw_string_ostream os(message);
472475
os << "illegal: operation operand/result data types did not align with any "

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2027,3 +2027,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
20272027
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
20282028
return %0 : tensor<2x52x3xf32>
20292029
}
2030+
2031+
// -----
2032+
2033+
func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
2034+
// expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
2035+
%0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
2036+
return %0 : tensor<1x12x11xf32>
2037+
}
2038+
2039+
// -----
2040+
2041+
func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
2042+
// expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
2043+
%0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
2044+
return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
2045+
}

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
4848

4949
// -----
5050

51-
func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
51+
func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
5252
// expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
53-
%0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
54-
return %0 : tensor<1x1x1x1x13x21x3xf32>
53+
%0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
54+
return %0 : tensor<1x1x1x1x13x21x3xi32>
5555
}
5656

5757
// -----

0 commit comments

Comments
 (0)