Skip to content

[mlir][tosa] Fix invalid data type combinations check #150066

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2025

Conversation

lhutton1
Copy link
Contributor

Previously this check assumed that if an operator exists in profile complimance (TosaProfileComplianceData.h.inc), an entry exists in both the profiles and extensions section. However, this is not necessarily the case.

This commit changes the check such that it doesn't assume the above. In doing so, it allows more operators to be checked for invalid data type combinations, which were otherwise skipped previously.

@llvmbot
Copy link
Member

llvmbot commented Jul 22, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

Previously this check assumed that if an operator exists in profile complimance (TosaProfileComplianceData.h.inc), an entry exists in both the profiles and extensions section. However, this is not necessarily the case.

This commit changes the check such that it doesn't assume the above. In doing so, it allows more operators to be checked for invalid data type combinations, which were otherwise skipped previously.


Full diff: https://github.com/llvm/llvm-project/pull/150066.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+9-2)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+16)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 88b0f3650ca01..09f09487129ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -464,9 +464,16 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
   CheckCondition condition = CheckCondition::invalid;
   const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
   const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+  if (failed(maybeProfDef) && failed(maybeExtDef))
+    return success();
+
+  bool hasEntry = false;
+  if (succeeded(maybeProfDef))
+    hasEntry |= maybeProfDef.value().size();
+  if (succeeded(maybeExtDef))
+    hasEntry |= maybeExtDef.value().size();
 
-  if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
-      !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+  if (!hasEntry) {
     std::string message;
     llvm::raw_string_ostream os(message);
     os << "illegal: operation operand/result data types did not align with any "
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 5a424c41775c9..95ebe0dfef0f7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2027,3 +2027,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
   %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
   return %0 : tensor<2x52x3xf32>
 }
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+  // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+  return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+  %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+  return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26fb1f85..ef4e0ae13efa7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
 
 // -----
 
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
   // expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
+  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
 }
 
 // -----

@lhutton1
Copy link
Contributor Author

Copy link
Contributor

@psunn psunn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo in commit message:
complimance -> compliance

Previously this check assumed that if an operator exists in profile
compliance (TosaProfileComplianceData.h.inc), an entry exists in both
the profiles and extensions section. However, this is not necessarily
the case.

This commit changes the check such that it doesn't assume the above.
In doing so, it allows more operators to be checked for invalid data
type combinations, which were otherwise skipped previously.

Change-Id: I2a7bc9be167463d29bf5d9ab1de946c26594845e
Copy link
Contributor

@psunn psunn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@lhutton1 lhutton1 merged commit fbdf4ec into llvm:main Jul 30, 2025
9 checks passed
@lhutton1 lhutton1 deleted the fix-dtype-check branch July 30, 2025 12:24
krishna2803 pushed a commit to krishna2803/llvm-project that referenced this pull request Aug 12, 2025
Previously this check assumed that if an operator exists in profile
complimance (TosaProfileComplianceData.h.inc), an entry exists in both
the profiles and extensions section. However, this is not necessarily
the case.

This commit changes the check such that it doesn't assume the above. In
doing so, it allows more operators to be checked for invalid data type
combinations, which were otherwise skipped previously.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants