Skip to content

Conversation

@hankluo6
Copy link
Contributor

Fixes #151786

fptosi produces poison when the input is Inf, and any subsequent use leads to undefined behavior. This patch adds a safe path, similar to the existing round expansion, for large or special inputs and avoids the UB.

@llvmbot
Copy link
Member

llvmbot commented Nov 30, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Hank (hankluo6)

Changes

Fixes #151786

fptosi produces poison when the input is Inf, and any subsequent use leads to undefined behavior. This patch adds a safe path, similar to the existing round expansion, for large or special inputs and avoids the UB.


Full diff: https://github.com/llvm/llvm-project/pull/170028.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp (+33-1)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+10-1)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index cd68039d0d964..e9f4811aae3fe 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -232,6 +232,37 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
+
+  auto operandETy = getElementTypeOrSelf(opType);
+  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
+  unsigned mantissaWidth =
+      llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
+  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
+
+  Type iTy = rewriter.getIntegerType(bitWidth);
+  if (auto shapedTy = dyn_cast<ShapedType>(opType))
+    iTy = shapedTy.clone(iTy);
+
+  Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b);
+  Value cBias =
+      createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b);
+  Value cExpMask =
+      createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b);
+
+  // Any floating-point value with an unbiased exponent ≥ `mantissaWidth`
+  // falls into one of these categories:
+  //   - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
+  //     numbers are already integral, or
+  //   - a special value (NaN or ±Inf), which also satisfies this exponent
+  //     condition.
+  // For all such cases, `ceilf(x)` is defined to return `x` directly.
+  Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
+  Value operandExp = arith::AndIOp::create(
+      b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask);
+  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias);
+  Value isSpecialValOrLargeVal = arith::CmpIOp::create(
+      b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth);
+
   Value fpFixedConvert = createTruncatedFPValue(operand, b);
 
   // Creating constants for later use.
@@ -243,7 +274,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value incrValue =
       arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+  Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+  Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
   rewriter.replaceOp(op, ret);
   return success();
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 615c607efc3c3..75f8e65b334a2 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -145,13 +145,22 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
 func.func @ceilf_func(%a: f64) -> f64 {
   // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
   // CHECK-DAG:   [[CST_0:%.+]] = arith.constant 1.000
+  // CHECK-DAG:   [[C52:%.*]] = arith.constant 52
+  // CHECK-DAG:   [[C1023:%.*]] = arith.constant 1023
+  // CHECK-DAG:   [[EXP_MASK:%.*]] = arith.constant 2047
+  // CHECK-NEXT:   [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64
+  // CHECK-NEXT:   [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]]
+  // CHECK-NEXT:   [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]]
+  // CHECK-NEXT:   [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]]
+  // CHECK-NEXT:   [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]]
   // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
   // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
   // CHECK-NEXT:   [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
   // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
   // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
   // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
-  // CHECK-NEXT:   return [[ADDF]]
+  // CHECK-NEXT:   [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
+  // CHECK-NEXT:   return [[RESULT]]
   // CHECK-FILTER: math.ceil
   %ret = math.ceil %a : f64
   return %ret : f64

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[MLIR] tosa.powf + tosa.ceil returns incorrect result with -test-expand-math

2 participants