Skip to content

Commit f725017

Browse files
authored
Implement acos operator in MLIR Math Dialect (#74584)
Required for torch-mlir. Cf. llvm/torch-mlir#2604 "Implement torch.aten.acos".
1 parent b842b1b commit f725017

File tree

4 files changed

+87
-0
lines changed

4 files changed

+87
-0
lines changed

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,35 @@ def Math_CosOp : Math_FloatUnaryOp<"cos"> {
300300
let hasFolder = 1;
301301
}
302302

303+
//===----------------------------------------------------------------------===//
304+
// AcosOp
305+
//===----------------------------------------------------------------------===//
306+
307+
def Math_AcosOp : Math_FloatUnaryOp<"acos"> {
308+
let summary = "arcus cosine of the specified value";
309+
let description = [{
310+
Syntax:
311+
312+
```
313+
operation ::= ssa-id `=` `math.acos` ssa-use `:` type
314+
```
315+
316+
The `acos` operation computes the arcus cosine of a given value. It takes one
317+
operand of floating point type (i.e., scalar, tensor or vector) and returns one
318+
result of the same type. It has no standard attributes.
319+
320+
Example:
321+
322+
```mlir
323+
// Scalar arcus cosine value.
324+
%a = math.acos %b : f64
325+
```
326+
}];
327+
let hasFolder = 1;
328+
}
329+
330+
331+
303332
//===----------------------------------------------------------------------===//
304333
// SinOp
305334
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
162162
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
163163
MLIRContext *ctx = patterns.getContext();
164164

165+
populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
165166
populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
166167
populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
167168
populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@ OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
4141
[](const APInt &a) { return a.abs(); });
4242
}
4343

44+
//===----------------------------------------------------------------------===//
45+
// AcosOp folder
46+
//===----------------------------------------------------------------------===//
47+
48+
OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
49+
return constFoldUnaryOpConditional<FloatAttr>(
50+
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
51+
switch (a.getSizeInBits(a.getSemantics())) {
52+
case 64:
53+
return APFloat(acos(a.convertToDouble()));
54+
case 32:
55+
return APFloat(acosf(a.convertToFloat()));
56+
default:
57+
return {};
58+
}
59+
});
60+
}
61+
4462
//===----------------------------------------------------------------------===//
4563
// AtanOp folder
4664
//===----------------------------------------------------------------------===//

mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
22

3+
// CHECK-DAG: @acos(f64) -> f64 attributes {llvm.readnone}
4+
// CHECK-DAG: @acosf(f32) -> f32 attributes {llvm.readnone}
35
// CHECK-DAG: @atan(f64) -> f64 attributes {llvm.readnone}
46
// CHECK-DAG: @atanf(f32) -> f32 attributes {llvm.readnone}
57
// CHECK-DAG: @erf(f64) -> f64 attributes {llvm.readnone}
@@ -29,6 +31,43 @@
2931
// CHECK-DAG: @ceil(f64) -> f64 attributes {llvm.readnone}
3032
// CHECK-DAG: @ceilf(f32) -> f32 attributes {llvm.readnone}
3133

34+
// CHECK-LABEL: func @acos_caller
35+
// CHECK-SAME: %[[FLOAT:.*]]: f32
36+
// CHECK-SAME: %[[DOUBLE:.*]]: f64
37+
func.func @acos_caller(%float: f32, %double: f64) -> (f32, f64) {
38+
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @acosf(%[[FLOAT]]) : (f32) -> f32
39+
%float_result = math.acos %float : f32
40+
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @acos(%[[DOUBLE]]) : (f64) -> f64
41+
%double_result = math.acos %double : f64
42+
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
43+
return %float_result, %double_result : f32, f64
44+
}
45+
46+
// CHECK-LABEL: func @acos_vec_caller(
47+
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
48+
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
49+
// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
50+
// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
51+
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
52+
// CHECK: %[[OUT0_F32:.*]] = call @acosf(%[[IN0_F32]]) : (f32) -> f32
53+
// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
54+
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
55+
// CHECK: %[[OUT1_F32:.*]] = call @acosf(%[[IN1_F32]]) : (f32) -> f32
56+
// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
57+
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
58+
// CHECK: %[[OUT0_F64:.*]] = call @acos(%[[IN0_F64]]) : (f64) -> f64
59+
// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
60+
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
61+
// CHECK: %[[OUT1_F64:.*]] = call @acos(%[[IN1_F64]]) : (f64) -> f64
62+
// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
63+
// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
64+
// CHECK: }
65+
func.func @acos_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
66+
%float_result = math.acos %float : vector<2xf32>
67+
%double_result = math.acos %double : vector<2xf64>
68+
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
69+
}
70+
3271
// CHECK-LABEL: func @atan_caller
3372
// CHECK-SAME: %[[FLOAT:.*]]: f32
3473
// CHECK-SAME: %[[DOUBLE:.*]]: f64

0 commit comments

Comments
 (0)