Skip to content

[mlir][math] Add clampf and clean math ExpandOps API #151153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

fabianmcg
Copy link
Contributor

This patch adds the clampf operation to the math dialect. The semantics op are defined as:

clampf(x, min_v, max_v) = max(min(x, min_v), max_v) 

The reasoning behind adding this operation is that some GPU vendors offer specialized intrinsics for this operation, or subsets of this operation. For example, __saturatef in NVIDIA GPUs, or __builtin_amdgcn_fmed3f in AMD GPUs.

This patch also removes test-expand-math in favor of math-expand-ops.
Finally, it removes individual expansion population API calls like populateExpandCoshPattern in favor of:

void populateExpansionPatterns(RewritePatternSet &patterns,
                               ArrayRef<StringRef> opMnemonics = {});

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:math labels Jul 29, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Fabian Mora (fabianmcg)

Changes

This patch adds the clampf operation to the math dialect. The semantics op are defined as:

clampf(x, min_v, max_v) = max(min(x, min_v), max_v) 

The reasoning behind adding this operation is that some GPU vendors offer specialized intrinsics for this operation, or subsets of this operation. For example, __saturatef in NVIDIA GPUs, or __builtin_amdgcn_fmed3f in AMD GPUs.

This patch also removes test-expand-math in favor of math-expand-ops.
Finally, it removes individual expansion population API calls like populateExpandCoshPattern in favor of:

void populateExpansionPatterns(RewritePatternSet &amp;patterns,
                               ArrayRef&lt;StringRef&gt; opMnemonics = {});

Patch is 20.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151153.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+31)
  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+10-16)
  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.td (+20)
  • (modified) mlir/lib/Dialect/Math/Transforms/CMakeLists.txt (+1-1)
  • (renamed) mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp (+77-62)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+34-1)
  • (modified) mlir/test/Dialect/Math/ops.mlir (+14-1)
  • (modified) mlir/test/lib/Dialect/Math/CMakeLists.txt (-1)
  • (removed) mlir/test/lib/Dialect/Math/TestExpandMath.cpp (-62)
  • (modified) mlir/test/mlir-runner/test-expand-math-approx.mlir (+1-1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (-2)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 56370388dea87..cfd8c4b8f11f7 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -352,6 +352,37 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ClampFOp
+//===----------------------------------------------------------------------===//
+
+def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> {
+  let summary = "floating point clamping operation";
+  let description = [{
+    The `clampf` operation takes three operands and returns one result, each of
+    these is required to be the same type. Operands must be of floating point type
+    (i.e., scalar, tensor or vector).
+
+    The semantics of the operation are described by:
+    ```
+      clampf(value, min, max) = maxf(minf(value, min), max)
+    ```
+
+    Example:
+
+    ```mlir
+    %d = math.clampf %value to [%min, %max] : f64
+    ```
+  }];
+  let arguments = (ins FloatLike:$value, FloatLike:$min, FloatLike:$max,
+      DefaultValuedAttr<Arith_FastMathAttr,
+                        "::mlir::arith::FastMathFlags::none">:$fastmath);
+  let assemblyFormat = [{
+    $value `to` ` ` `[` $min `,` $max `]` (`fastmath` `` $fastmath^)?
+    attr-dict `:` type($result)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // CopySignOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index c0fe5d3be448a..b3abbf728a3c6 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -23,22 +23,16 @@ class ConversionTarget;
 class RewritePatternSet;
 class TypeConverter;
 
-void populateExpandCtlzPattern(RewritePatternSet &patterns);
-void populateExpandTanPattern(RewritePatternSet &patterns);
-void populateExpandSinhPattern(RewritePatternSet &patterns);
-void populateExpandCoshPattern(RewritePatternSet &patterns);
-void populateExpandTanhPattern(RewritePatternSet &patterns);
-void populateExpandAsinhPattern(RewritePatternSet &patterns);
-void populateExpandAcoshPattern(RewritePatternSet &patterns);
-void populateExpandAtanhPattern(RewritePatternSet &patterns);
-void populateExpandFmaFPattern(RewritePatternSet &patterns);
-void populateExpandCeilFPattern(RewritePatternSet &patterns);
-void populateExpandExp2FPattern(RewritePatternSet &patterns);
-void populateExpandPowFPattern(RewritePatternSet &patterns);
-void populateExpandFPowIPattern(RewritePatternSet &patterns);
-void populateExpandRoundFPattern(RewritePatternSet &patterns);
-void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
-void populateExpandRsqrtPattern(RewritePatternSet &patterns);
+namespace math {
+/// Adds patterns to expand math operations into other more fundamental
+/// operations. For example, hyperbolic functions are expanded into expressions
+/// using `exp`. If `opMnemonics` is empty then all available patterns will be
+/// added, otherwise only the patterns corresponding to ops in `opMnemonics`
+/// will be added to the set.
+void populateExpansionPatterns(RewritePatternSet &patterns,
+                               ArrayRef<StringRef> opMnemonics = {});
+} // namespace math
+
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index a84c89020d4f3..4d415aeac8f58 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -44,4 +44,24 @@ def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
 }
 
+def MathExpandOpsPass : Pass<"math-expand-ops"> {
+  let summary = "Expand math operations.";
+  let description = [{
+    Expands some math operations into more fundamental operations, allowing them
+    to be subsequently lowered through these. For example, hyperbolic functions
+    are transformed into their expanded form containing only `exp` functions.
+
+    The `ops` parameter can be used to apply only a subset of all the
+    available expansions, these must correspond to the operation mnemonic.
+    For example, `ops=sinh,acosh` will expand only `math.sinh` and
+    `math.acosh` operations. If the list is empty, then all expansions are
+    applied.
+  }];
+  let dependentDialects = ["arith::ArithDialect"];
+  let options = [
+    ListOption<"opMnemonics", "ops", "std::string",
+               "Operations to expand.">
+  ];
+}
+
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index e1c0c2410c126..d37a056e8e158 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
-  ExpandPatterns.cpp
+  ExpandOps.cpp
   ExtendToSupportedTypes.cpp
   PolynomialApproximation.cpp
   UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
similarity index 89%
rename from mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
rename to mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index 4a40a3055ed62..cd68039d0d964 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -13,14 +13,18 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXPANDOPSPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
 /// Create a float constant.
 static Value createFloatConst(Location loc, Type type, APFloat value,
                               OpBuilder &b) {
@@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
   return success();
 }
 
-void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
-  patterns.add(convertCtlzOp);
-}
-
-void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertSinhOp);
-}
-
-void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
-  patterns.add(convertCoshOp);
-}
-
-void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
-  patterns.add(convertTanOp);
-}
-
-void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertTanhOp);
-}
-
-void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertAsinhOp);
-}
-
-void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
-  patterns.add(convertAcoshOp);
-}
-
-void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertAtanhOp);
-}
-
-void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertFmaFOp);
-}
-
-void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertCeilOp);
-}
-
-void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
-  patterns.add(convertExp2fOp);
-}
-
-void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertPowfOp);
-}
-
-void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
-  patterns.add(convertFPowIOp);
-}
-
-void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
-  patterns.add(convertRoundOp);
+// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
+static LogicalResult convertClampfOp(math::ClampFOp op,
+                                     PatternRewriter &rewriter) {
+  auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
+                                         op.getMin(), op.getFastmath());
+  rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
+                                                 op.getFastmath());
+  return success();
 }
 
-void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
-  patterns.add(convertRoundEvenOp);
+void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns,
+                                           ArrayRef<StringRef> opMnemonics) {
+  auto filter = [&](StringRef name) {
+    // This should be a static assert and `consume_front` take a twine, but none
+    // is currently possible. TODO: augment `StringRef::consume_front` and make
+    // `getDialectNamespace` use `std::string_view`.
+    assert("math" == MathDialect::getDialectNamespace());
+    name.consume_front("math.");
+    return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
+  };
+  if (filter(CountLeadingZerosOp::getOperationName()))
+    patterns.add(convertCtlzOp);
+  if (filter(SinhOp::getOperationName()))
+    patterns.add(convertSinhOp);
+  if (filter(CoshOp::getOperationName()))
+    patterns.add(convertCoshOp);
+  if (filter(TanOp::getOperationName()))
+    patterns.add(convertTanOp);
+  if (filter(TanhOp::getOperationName()))
+    patterns.add(convertTanhOp);
+  if (filter(AsinhOp::getOperationName()))
+    patterns.add(convertAsinhOp);
+  if (filter(AcoshOp::getOperationName()))
+    patterns.add(convertAcoshOp);
+  if (filter(AtanhOp::getOperationName()))
+    patterns.add(convertAtanhOp);
+  if (filter(FmaOp::getOperationName()))
+    patterns.add(convertFmaFOp);
+  if (filter(CeilOp::getOperationName()))
+    patterns.add(convertCeilOp);
+  if (filter(Exp2Op::getOperationName()))
+    patterns.add(convertExp2fOp);
+  if (filter(PowFOp::getOperationName()))
+    patterns.add(convertPowfOp);
+  if (filter(FPowIOp::getOperationName()))
+    patterns.add(convertFPowIOp);
+  if (filter(RoundOp::getOperationName()))
+    patterns.add(convertRoundOp);
+  if (filter(RoundEvenOp::getOperationName()))
+    patterns.add(convertRoundEvenOp);
+  if (filter(RsqrtOp::getOperationName()))
+    patterns.add(convertRsqrtOp);
+  if (filter(ClampFOp::getOperationName()))
+    patterns.add(convertClampfOp);
 }
 
-void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
-  patterns.add(convertRsqrtOp);
-}
+//===----------------------------------------------------------------------===//
+// MathExpandOpsPass pass
+//===----------------------------------------------------------------------===//
+namespace {
+struct MathExpandOpsPass final
+    : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
+  using MathExpandOpsPassBase::MathExpandOpsPassBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    SmallVector<StringRef> mnemonics =
+        llvm::to_vector_of<StringRef>(opMnemonics);
+    math::populateExpansionPatterns(patterns, mnemonics);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1420acaa40d35..615c607efc3c3 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -1,7 +1,9 @@
-// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-expand-ops | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-expand-ops=ops=tanh,tan | FileCheck %s --check-prefix=CHECK-FILTER
 
 // CHECK-LABEL: func @tanh
 func.func @tanh(%arg: f32) -> f32 {
+  // CHECK-FILTER-NOT: math.tanh
   %res = math.tanh %arg : f32
   return %res : f32
 }
@@ -27,6 +29,7 @@ func.func @tanh(%arg: f32) -> f32 {
 // CHECK-LABEL: func @vector_tanh
 func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
   // CHECK-NOT: math.tanh
+  // CHECK-FILTER-NOT: math.tanh
   %res = math.tanh %arg : vector<4xf32>
   return %res : vector<4xf32>
 }
@@ -35,6 +38,7 @@ func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
 
 // CHECK-LABEL: func @tan
 func.func @tan(%arg: f32) -> f32 {
+  // CHECK-FILTER-NOT: math.tan
   %res = math.tan %arg : f32
   return %res : f32
 }
@@ -49,6 +53,7 @@ func.func @tan(%arg: f32) -> f32 {
 
 // CHECK-LABEL: func @vector_tan
 func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
+  // CHECK-FILTER-NOT: math.tan
   %res = math.tan %arg : vector<4xf32>
   return %res : vector<4xf32>
 }
@@ -58,6 +63,7 @@ func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
 // -----
 
 func.func @ctlz(%arg: i32) -> i32 {
+  // CHECK-FILTER: math.ctlz
   %res = math.ctlz %arg : i32
   return %res : i32
 }
@@ -112,6 +118,7 @@ func.func @ctlz(%arg: i32) -> i32 {
 // -----
 
 func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  // CHECK-FILTER: math.ctlz
   %res = math.ctlz %arg : vector<4xi32>
   return %res : vector<4xi32>
 }
@@ -145,6 +152,7 @@ func.func @ceilf_func(%a: f64) -> f64 {
   // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
   // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
   // CHECK-NEXT:   return [[ADDF]]
+  // CHECK-FILTER: math.ceil
   %ret = math.ceil %a : f64
   return %ret : f64
 }
@@ -158,6 +166,7 @@ func.func @exp2f_func(%a: f64) -> f64 {
   // CHECK:         [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
   // CHECK:         [[EXP:%.+]]  = math.exp [[MULF]]
   // CHECK:         return [[EXP]]
+  // CHECK-FILTER: math.exp2
   %ret = math.exp2 %a : f64
   return %ret : f64
 }
@@ -813,3 +822,27 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
   %a = math.rsqrt %arg : tensor<*xf32>
   return %a: tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL:    func.func @clampf_scalar_op
+// CHECK-SAME:     (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16)
+// CHECK:          %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] : f16
+// CHECK:          %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] : f16
+// CHECK:          return %[[V1]] : f16
+
+func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
+  %a = math.clampf %arg to [%min, %max] : f16
+  return %a: f16
+}
+
+// CHECK-LABEL:    func.func @clampf_vector_op
+// CHECK-SAME:     (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>)
+// CHECK:          %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
+// CHECK:          %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
+// CHECK:          return %[[V1]] : vector<3x4xf32>
+
+func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{
+  %a = math.clampf %arg to [%min, %max] fastmath<fast> : vector<3x4xf32>
+  return %a: vector<3x4xf32>
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 8feadedd1860e..cb10fc4397ffc 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
 // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: func @atan(
@@ -337,3 +337,16 @@ func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>)
   math.isnormal %t : tensor<4x?xf32>
   return
 }
+
+// CHECK-LABEL: func @clampf(
+func.func @clampf(%av: vector<3x4xf32>, %mv: vector<3x4xf32>, %Mv: vector<3x4xf32>,
+                  %as: f32, %ms: f32, %Ms: f32,
+                  %at: tensor<?xf80>, %mt: tensor<?xf80>, %Mt: tensor<?xf80>) {
+  // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] fastmath<fast> : vector<3x4xf32>
+  %rv = math.clampf %av to [%mv, %Mv] fastmath<fast> : vector<3x4xf32>
+  // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : f32
+  %rs = math.clampf %as to [%ms, %Ms] fastmath<none> : f32
+  // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : tensor<?xf80>
+  %rt = math.clampf %at to [%mt, %Mt] : tensor<?xf80>
+  return 
+}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
index 91e70d1785369..900dff3b5e9f1 100644
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,7 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMathTestPasses
   TestAlgebraicSimplification.cpp
-  TestExpandMath.cpp
   TestPolynomialApproximation.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
deleted file mode 100644
index efc1acf8bb6cd..0000000000000
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-//===- TestExpandMath.cpp - Test expand math op into exp form -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains test passes for expanding math operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestExpandMathPass
-    : public PassWrapper<TestExpandMathPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
-
-  void runOnOperation() override;
-  StringRef getArgument() const final { return "test-expand-math"; }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>();
-  }
-  StringRef getDescription() const final { return "Test expanding math"; }
-};
-} // namespace
-
-void TestExpandMathPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  populateExpandCtlzPattern(patterns);
-  populateExpandExp2FPattern(patterns);
-  populateExpandTanPattern(patterns);
-  populateExpandSinhPattern(patterns);
-  populateExpandCoshPattern(patterns);
-  populateExpandTanhPattern(patterns);
-  populateExpandAsinhPattern(patterns);
-  populateExpandAcoshPattern(patterns);
-  populateExpandAtanhPattern(patterns);
-  populateExpandFmaFPattern(patterns);
-  populateExpandCeilFPattern(patterns);
-  populateExpandPowFPattern(patterns);
-  populateExpandFPowIPattern(patterns);
-  populateExpandRoundFPattern(patterns);
-  populateExpandRoundEvenPattern(patterns);
-  populateExpandRsqrtPattern(patterns);
-  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-}
-
-namespace mlir {
-namespace test {
-void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index b599c9d8435d4..3f9d3f2125e1a 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -1,4 +1,4 @@
-// RUN:   mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \
+// RUN:   mlir-opt %s -pass-pipeline="builtin.module(func.func(math-expand-ops),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \
 // RUN: | mlir-runner                                                      \
 // RUN:     -e main -entry-point-result=void -O0                               ...
[truncated]

@fabianmcg fabianmcg changed the title [mlir][math] Add clampf and clean math ExpandOps [mlir][math] Add clampf and clean math ExpandOps API Jul 29, 2025
@fabianmcg fabianmcg removed mlir:core MLIR Core Infrastructure mlir labels Jul 29, 2025
@fabianmcg fabianmcg requested review from jpienaar and joker-eph July 29, 2025 15:55
@MaheshRavishankar MaheshRavishankar requested a review from bjacob July 30, 2025 09:50
Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants