-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[CIR][X86] Implement lowering for sqrt builtins #169310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
627bcb3
d095f5a
8c13c6f
17a3e66
01bb815
4a39fd7
ef3fd97
9705673
21119e5
90878ec
3529f40
92d0ac3
0385662
ddcb7b8
1e846e7
233efad
9d940bc
51bbcca
f901f03
e5789b6
8937b12
8a02c50
9923a62
82a9395
6bd3282
8232ce8
bc8e4cc
9284761
8647b5c
4bac65a
b1ff2ab
8843006
ed82423
9e8bec2
961c9f9
e5d1a6d
44ddd79
4dd8aa0
cc5ffa1
6d43c43
47f9b2f
15f1f4f
b12779a
996b7e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -781,9 +781,21 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, | |
| case X86::BI__builtin_ia32_sqrtsh_round_mask: | ||
| case X86::BI__builtin_ia32_sqrtsd_round_mask: | ||
| case X86::BI__builtin_ia32_sqrtss_round_mask: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You need to insert a call to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. May I just move the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No. We are attempting to keep builtins in the same relative order as they appear in classic codegen and the incubator in order to make the code easier to navigate.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Understood! |
||
| case X86::BI__builtin_ia32_sqrtpd256: | ||
| case X86::BI__builtin_ia32_sqrtpd: | ||
| case X86::BI__builtin_ia32_sqrtps256: | ||
| case X86::BI__builtin_ia32_sqrtps: | ||
| case X86::BI__builtin_ia32_sqrtph256: | ||
| case X86::BI__builtin_ia32_sqrtph: | ||
| errorNYI("Unimplemented builtin"); | ||
Priyanshu3820 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return {}; | ||
|
||
| case X86::BI__builtin_ia32_sqrtph512: | ||
| case X86::BI__builtin_ia32_sqrtps512: | ||
| case X86::BI__builtin_ia32_sqrtpd512: | ||
| case X86::BI__builtin_ia32_sqrtpd512: { | ||
| mlir::Location loc = getLoc(expr->getExprLoc()); | ||
| mlir::Value arg = ops[0]; | ||
| return builder.create<cir::SqrtOp>(loc, arg.getType(), arg).getResult(); | ||
Priyanshu3820 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| case X86::BI__builtin_ia32_pmuludq128: | ||
| case X86::BI__builtin_ia32_pmuludq256: | ||
| case X86::BI__builtin_ia32_pmuludq512: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| //====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===// | ||
| //====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===// | ||
|
||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
|
|
@@ -30,6 +30,7 @@ | |
| #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" | ||
| #include "mlir/Target/LLVMIR/Export.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
| #include "clang/Basic/LLVM.h" | ||
|
||
| #include "clang/CIR/Dialect/IR/CIRAttrs.h" | ||
| #include "clang/CIR/Dialect/IR/CIRDialect.h" | ||
| #include "clang/CIR/Dialect/IR/CIRTypes.h" | ||
|
|
@@ -45,6 +46,93 @@ | |
| using namespace cir; | ||
| using namespace llvm; | ||
|
|
||
|
|
||
| static std::string getLLVMIntrinsicNameForType(mlir::Type llvmTy) { | ||
|
||
| std::string s; | ||
| { | ||
| llvm::raw_string_ostream os(s); | ||
| os << llvmTy; | ||
| } | ||
| return s; | ||
| } | ||
|
|
||
| // Actual lowering | ||
|
||
| mlir::LogicalResult CIRToLLVMSqrtOpLowering::matchAndRewrite( | ||
| cir::SqrtOp op, typename cir::SqrtOp::Adaptor adaptor, | ||
| mlir::ConversionPatternRewriter &rewriter) const { | ||
|
|
||
| mlir::Location loc = op.getLoc(); | ||
| mlir::MLIRContext *ctx = rewriter.getContext(); | ||
|
|
||
| mlir::Type cirResTy = op.getResult().getType(); | ||
| mlir::Type llvmResTy = getTypeConverter()->convertType(cirResTy); | ||
| if (!llvmResTy) | ||
| return op.emitOpError( | ||
| "expected LLVM dialect result type for cir.sqrt lowering"); | ||
|
|
||
| Value operand = adaptor.getInput(); | ||
| Value llvmOperand = operand; | ||
| if (operand.getType() != llvmResTy) { | ||
|
||
| llvmOperand = rewriter.create<LLVM::BitcastOp>(loc, llvmResTy, operand); | ||
| } | ||
|
|
||
| // Build the llvm.sqrt.* intrinsic name depending on scalar vs vector result | ||
| std::string intrinsicName = "llvm.sqrt."; | ||
| std::string suffix; | ||
|
|
||
| // If the CIR result type is a vector, include the 'vN' part in the suffix. | ||
| if (auto vec = cirResTy.dyn_cast<cir::VectorType>()) { | ||
| Type elt = vec.getElementType(); | ||
| if (auto f = elt.dyn_cast<cir::FloatType>()) { | ||
| unsigned width = f.getWidth(); | ||
| unsigned n = vec.getNumElements(); | ||
| if (width == 32) | ||
| suffix = "v" + std::to_string(n) + "f32"; | ||
| else if (width == 64) | ||
| suffix = "v" + std::to_string(n) + "f64"; | ||
| else if (width == 16) | ||
| suffix = "v" + std::to_string(n) + "f16"; | ||
| else | ||
| return op.emitOpError("unsupported float width for sqrt"); | ||
| } else { | ||
| return op.emitOpError("vector element must be floating point for sqrt"); | ||
| } | ||
| } else if (auto f = cirResTy.dyn_cast<cir::FloatType>()) { | ||
| // Scalar float | ||
| unsigned width = f.getWidth(); | ||
| if (width == 32) | ||
| suffix = "f32"; | ||
| else if (width == 64) | ||
| suffix = "f64"; | ||
| else if (width == 16) | ||
| suffix = "f16"; | ||
| else | ||
| return op.emitOpError("unsupported float width for sqrt"); | ||
| } else { | ||
| return op.emitOpError("unsupported type for cir.sqrt lowering"); | ||
| } | ||
|
|
||
| intrinsicName += suffix; | ||
|
|
||
| // Ensure the llvm intrinsic function exists at module scope. Insert it at | ||
| // the start of the module body using an insertion guard. | ||
| ModuleOp module = op->getParentOfType<ModuleOp>(); | ||
| if (!module.lookupSymbol<LLVM::LLVMFuncOp>(intrinsicName)) { | ||
| OpBuilder::InsertionGuard guard(rewriter); | ||
| rewriter.setInsertionPointToStart(module.getBody()); | ||
| auto llvmFnType = LLVM::LLVMFunctionType::get(ctx, llvmResTy, {llvmResTy}, | ||
| /*isVarArg=*/false); | ||
| rewriter.create<LLVM::LLVMFuncOp>(loc, intrinsicName, llvmFnType); | ||
| } | ||
|
|
||
| // Create the call and replace cir.sqrt | ||
| auto callee = SymbolRefAttr::get(ctx, intrinsicName); | ||
| rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, llvmResTy, callee, | ||
| ArrayRef<Value>{llvmOperand}); | ||
|
|
||
| return mlir::success(); | ||
Priyanshu3820 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| namespace cir { | ||
| namespace direct { | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,19 @@ | |
| #include "mlir/Transforms/DialectConversion.h" | ||
| #include "clang/CIR/Dialect/IR/CIRDialect.h" | ||
|
|
||
| namespace cir { | ||
| class SqrtOp; | ||
| } | ||
|
|
||
| class CIRToLLVMSqrtOpLowering : public mlir::OpConversionPattern<cir::SqrtOp> { | ||
|
||
| public: | ||
| using mlir::OpConversionPattern<cir::SqrtOp>::OpConversionPattern; | ||
|
|
||
| mlir::LogicalResult | ||
| matchAndRewrite(cir::SqrtOp op, typename cir::SqrtOp::Adaptor adaptor, | ||
| mlir::ConversionPatternRewriter &rewriter) const override; | ||
| }; | ||
|
|
||
| namespace cir { | ||
|
|
||
| namespace direct { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| #include <immintrin.h> | ||
| // Test X86-specific sqrt builtins | ||
|
|
||
| // RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir | ||
| // RUN: FileCheck --input-file=%t.cir %s | ||
Priyanshu3820 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| // Test __builtin_ia32_sqrtph512 | ||
| __m512h test_sqrtph512(__m512h a) { | ||
| return __builtin_ia32_sqrtph512(a); | ||
| } | ||
| // CHECK: cir.func @test_sqrtph512 | ||
| // CHECK: [[RES:%.*]] = cir.sqrt {{%.*}} : !cir.vector<!cir.fp16 x 32> | ||
| // CHECK: cir.return [[RES]] | ||
|
|
||
| // Test __builtin_ia32_sqrtps512 | ||
| __m512 test_sqrtps512(__m512 a) { | ||
| return __builtin_ia32_sqrtps512(a); | ||
| } | ||
| // CHECK: cir.func @test_sqrtps512 | ||
| // CHECK: [[RES:%.*]] = cir.sqrt {{%.*}} : !cir.vector<!cir.float x 16> | ||
| // CHECK: cir.return [[RES]] | ||
|
|
||
| // Test __builtin_ia32_sqrtpd512 | ||
| __m512d test_sqrtpd512(__m512d a) { | ||
| return __builtin_ia32_sqrtpd512(a); | ||
| } | ||
| // CHECK: cir.func @test_sqrtpd512 | ||
| // CHECK: [[RES:%.*]] = cir.sqrt {{%.*}} : !cir.vector<!cir.double x 8> | ||
| // CHECK: cir.return [[RES]] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a description with an example of the CIR format?
We're trying to standardize the descriptions of all our operations to include an example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had done that initially but removed the example after seeing that other floating point ops didn't have examples either. Added it again now.