-
Notifications
You must be signed in to change notification settings - Fork 13.6k
Add cosh op to the math dialect. #75153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-math Author: Sungsoon Cho (godot73) ChangesFull diff: https://github.com/llvm/llvm-project/pull/75153.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 9742d3d936dff5..b9daa91b28a9bd 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -327,7 +327,26 @@ def Math_AcosOp : Math_FloatUnaryOp<"acos"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// CoshOp
+//===----------------------------------------------------------------------===//
+def Math_CoshOp : Math_FloatUnaryOp<"cosh"> {
+ let summary = "hyperbolic cosine of the specified value";
+ let description = [{
+ The `cosh` operation computes the hyperbolic cosine. It takes one operand
+ of floating point type (i.e., scalar, tensor or vector) and returns one
+ result of the same type. It has no standard attributes.
+
+ Example:
+
+ ```mlir
+ // Scalar hyperbolic cosine value.
+ %a = math.cosh %b : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
//===----------------------------------------------------------------------===//
// SinOp
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 27c2cb93520714..6e30c07de4d57e 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -168,6 +168,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
+ populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 066a21c76f7d1c..6b8c3a53a422fa 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -144,6 +144,24 @@ OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// CoshOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
+ return constFoldUnaryOpConditional<FloatAttr>(
+ adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+ switch (a.getSizeInBits(a.getSemantics())) {
+ case 64:
+ return APFloat(cosh(a.convertToDouble()));
+ case 32:
+ return APFloat(coshf(a.convertToFloat()));
+ default:
+ return {};
+ }
+ });
+}
+
//===----------------------------------------------------------------------===//
// SinOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index f0c4512cbfdcc7..eb9226dee2619d 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -24,6 +24,8 @@
// CHECK-DAG: @truncf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @cos(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @cosf(f32) -> f32 attributes {llvm.readnone}
+// CHECK-DAG: @cosh(f64) -> f64 attributes {llvm.readnone}
+// CHECK-DAG: @coshf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @sin(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @sinf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @floor(f64) -> f64 attributes {llvm.readnone}
@@ -127,6 +129,18 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) {
return %float_result, %double_result : f32, f64
}
+// CHECK-LABEL: func @cosh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @cosh_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @coshf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.cosh %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cosh(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.cosh %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
// CHECK-LABEL: func @atan2_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
|
@llvm/pr-subscribers-mlir Author: Sungsoon Cho (godot73) ChangesFull diff: https://github.com/llvm/llvm-project/pull/75153.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 9742d3d936dff5..b9daa91b28a9bd 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -327,7 +327,26 @@ def Math_AcosOp : Math_FloatUnaryOp<"acos"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// CoshOp
+//===----------------------------------------------------------------------===//
+def Math_CoshOp : Math_FloatUnaryOp<"cosh"> {
+ let summary = "hyperbolic cosine of the specified value";
+ let description = [{
+ The `cosh` operation computes the hyperbolic cosine. It takes one operand
+ of floating point type (i.e., scalar, tensor or vector) and returns one
+ result of the same type. It has no standard attributes.
+
+ Example:
+
+ ```mlir
+ // Scalar hyperbolic cosine value.
+ %a = math.cosh %b : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
//===----------------------------------------------------------------------===//
// SinOp
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 27c2cb93520714..6e30c07de4d57e 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -168,6 +168,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
+ populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 066a21c76f7d1c..6b8c3a53a422fa 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -144,6 +144,24 @@ OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// CoshOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
+ return constFoldUnaryOpConditional<FloatAttr>(
+ adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+ switch (a.getSizeInBits(a.getSemantics())) {
+ case 64:
+ return APFloat(cosh(a.convertToDouble()));
+ case 32:
+ return APFloat(coshf(a.convertToFloat()));
+ default:
+ return {};
+ }
+ });
+}
+
//===----------------------------------------------------------------------===//
// SinOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index f0c4512cbfdcc7..eb9226dee2619d 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -24,6 +24,8 @@
// CHECK-DAG: @truncf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @cos(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @cosf(f32) -> f32 attributes {llvm.readnone}
+// CHECK-DAG: @cosh(f64) -> f64 attributes {llvm.readnone}
+// CHECK-DAG: @coshf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @sin(f64) -> f64 attributes {llvm.readnone}
// CHECK-DAG: @sinf(f32) -> f32 attributes {llvm.readnone}
// CHECK-DAG: @floor(f64) -> f64 attributes {llvm.readnone}
@@ -127,6 +129,18 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) {
return %float_result, %double_result : f32, f64
}
+// CHECK-LABEL: func @cosh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @cosh_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @coshf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.cosh %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cosh(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.cosh %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
// CHECK-LABEL: func @atan2_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
|
Similar to #74584, this PR is required to implement lowering of torch.aten.cosh. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
@@ -127,6 +129,18 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) { | |||
return %float_result, %double_result : f32, f64 | |||
} | |||
|
|||
// CHECK-LABEL: func @cosh_caller |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For other operators, there are also tests which apply the operator on a vector. Since this functionality does not require a specific implementation for each operator, I suppose it is ok not to duplicate the test case for each operator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@stellaraccident, can you please merge this PR? |
@stellaraccident @vivekkhandelwal1 I already did it. |
Thanks! |
No description provided.