diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index caa047a51b689..a0a553d6b50fd 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -4756,6 +4756,27 @@ class CIR_UnaryFPToFPBuiltinOp
   let llvmOp = llvmOpName;
 }
 
+def CIR_SqrtOp : CIR_UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp"> {
+  let summary = "Floating-point square root operation";
+
+  let description = [{
+    Computes the square root of a floating-point value or vector.
+
+    The input must be either:
+      - a floating-point scalar type, or
+      - a vector whose element type is floating-point.
+
+    The result type must match the input type exactly.
+
+    Examples:
+      // scalar
+      %r = cir.sqrt %x : !cir.double
+
+      // vector
+      %v = cir.sqrt %vec : !cir.vector<4 x !cir.float>
+  }];
+}
+
 def CIR_ACosOp : CIR_UnaryFPToFPBuiltinOp<"acos", "ACosOp"> {
   let summary = "Computes the arcus cosine of the specified value";
   let description = [{
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
index 1c1ef4da20b0d..fb17e31bf36d6 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
@@ -1347,13 +1347,29 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
   case X86::BI__builtin_ia32_sqrtsh_round_mask:
   case X86::BI__builtin_ia32_sqrtsd_round_mask:
   case X86::BI__builtin_ia32_sqrtss_round_mask:
-  case X86::BI__builtin_ia32_sqrtph512:
-  case X86::BI__builtin_ia32_sqrtps512:
-  case X86::BI__builtin_ia32_sqrtpd512:
     cgm.errorNYI(expr->getSourceRange(),
                  std::string("unimplemented X86 builtin call: ") +
                      getContext().BuiltinInfo.getName(builtinID));
     return {};
+  case X86::BI__builtin_ia32_sqrtph512:
+  case X86::BI__builtin_ia32_sqrtps512:
+  case X86::BI__builtin_ia32_sqrtpd512: {
+    // Only rounding mode 4 (_MM_FROUND_CUR_DIRECTION) is equivalent to a plain
+    // sqrt; any other mode needs a rounding-aware lowering that is not
+    // implemented yet, so report it instead of silently dropping the operand.
+    if (ops.size() == 2) {
+      auto rounding = ops[1].getDefiningOp<cir::ConstantOp>();
+      if (!rounding || rounding.getIntValue().getZExtValue() != 4) {
+        cgm.errorNYI(expr->getSourceRange(),
+                     std::string("unimplemented X86 builtin call: ") +
+                         getContext().BuiltinInfo.getName(builtinID));
+        return {};
+      }
+    }
+    mlir::Location loc = getLoc(expr->getExprLoc());
+    mlir::Value arg = ops[0];
+    return cir::SqrtOp::create(builder, loc, arg.getType(), arg).getResult();
+  }
   case X86::BI__builtin_ia32_pmuludq128:
   case X86::BI__builtin_ia32_pmuludq256:
   case X86::BI__builtin_ia32_pmuludq512: {
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 97bd3cf850daa..bda74b7fbdf6e 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -186,6 +186,15 @@ mlir::LogicalResult CIRToLLVMCopyOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+// Lowers cir.sqrt to mlir::LLVM::SqrtOp using the converted result type.
+mlir::LogicalResult CIRToLLVMSqrtOpLowering::matchAndRewrite(
+    cir::SqrtOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  mlir::Type resTy = typeConverter->convertType(op.getType());
+  rewriter.replaceOpWithNewOp<mlir::LLVM::SqrtOp>(op, resTy, adaptor.getSrc());
+  return mlir::success();
+}
+
 mlir::LogicalResult CIRToLLVMCosOpLowering::matchAndRewrite(
     cir::CosOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/CodeGen/X86/cir-sqrt-builtins.c b/clang/test/CIR/CodeGen/X86/cir-sqrt-builtins.c
new file mode 100644
index 0000000000000..d540e9c227e67
--- /dev/null
+++ b/clang/test/CIR/CodeGen/X86/cir-sqrt-builtins.c
@@ -0,0 +1,45 @@
+// Test X86-specific sqrt builtins (rounding argument 4 = current direction).
+
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -target-feature +avx512f -target-feature +avx512fp16 -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -target-feature +avx512f -target-feature +avx512fp16 -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --check-prefix=LLVM --input-file=%t-cir.ll %s
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -target-feature +avx512f -target-feature +avx512fp16 -emit-llvm %s -o %t.ll
+// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s
+
+typedef float __m512 __attribute__((__vector_size__(64), __aligned__(64)));
+typedef double __m512d __attribute__((__vector_size__(64), __aligned__(64)));
+typedef _Float16 __m512h __attribute__((__vector_size__(64), __aligned__(64)));
+
+// Test __builtin_ia32_sqrtph512
+__m512h test_sqrtph512(__m512h a) {
+  return __builtin_ia32_sqrtph512(a, 4);
+}
+// CIR-LABEL: cir.func {{.*}}@test_sqrtph512
+// CIR: cir.sqrt {{%.*}} : !cir.vector<32 x !cir.f16>
+// LLVM-LABEL: define {{.*}} @test_sqrtph512
+// LLVM: call <32 x half> @llvm.sqrt.v32f16
+// OGCG-LABEL: define {{.*}} @test_sqrtph512
+// OGCG: call <32 x half> @llvm.sqrt.v32f16
+
+// Test __builtin_ia32_sqrtps512
+__m512 test_sqrtps512(__m512 a) {
+  return __builtin_ia32_sqrtps512(a, 4);
+}
+// CIR-LABEL: cir.func {{.*}}@test_sqrtps512
+// CIR: cir.sqrt {{%.*}} : !cir.vector<16 x !cir.float>
+// LLVM-LABEL: define {{.*}} @test_sqrtps512
+// LLVM: call <16 x float> @llvm.sqrt.v16f32
+// OGCG-LABEL: define {{.*}} @test_sqrtps512
+// OGCG: call <16 x float> @llvm.sqrt.v16f32
+
+// Test __builtin_ia32_sqrtpd512
+__m512d test_sqrtpd512(__m512d a) {
+  return __builtin_ia32_sqrtpd512(a, 4);
+}
+// CIR-LABEL: cir.func {{.*}}@test_sqrtpd512
+// CIR: cir.sqrt {{%.*}} : !cir.vector<8 x !cir.double>
+// LLVM-LABEL: define {{.*}} @test_sqrtpd512
+// LLVM: call <8 x double> @llvm.sqrt.v8f64
+// OGCG-LABEL: define {{.*}} @test_sqrtpd512
+// OGCG: call <8 x double> @llvm.sqrt.v8f64