diff --git a/lib/SPIRVProducerPass.cpp b/lib/SPIRVProducerPass.cpp index aa6d242b6..d098c25cb 100644 --- a/lib/SPIRVProducerPass.cpp +++ b/lib/SPIRVProducerPass.cpp @@ -4516,7 +4516,42 @@ SPIRVID SPIRVProducerPassImpl::GenerateInstructionFromCall(CallInst *Call) { glsl::ExtInst EInst = Builtins::getDirectOrIndirectExtInstEnum(func_info); // Do not replace functions with implementations. - if (EInst && Call->getCalledFunction()->isDeclaration()) { + if (EInst && EInst == glsl::ExtInst::ExtInstFma && + !clspv::Option::UnsafeMath() && + clspv::Option::UseNativeBuiltins().count( + clspv::Builtins::BuiltinType::kFma) == 0 && + Call->getType()->getScalarType()->isHalfTy()) { + Type *FP16Ty = Call->getType(); + + Type *FP32Ty = Type::getFloatTy(Context); + if (FP16Ty->isVectorTy()) { + FP32Ty = FixedVectorType::get( + FP32Ty, dyn_cast(FP16Ty)->getNumElements()); + } + + SPIRVOperandVec Ops; + Ops.clear(); + Ops << FP32Ty << Call->getOperand(0); + auto a_f32 = addSPIRVInst(spv::OpFConvert, Ops); + + Ops.clear(); + Ops << FP32Ty << Call->getOperand(1); + auto b_f32 = addSPIRVInst(spv::OpFConvert, Ops); + + Ops.clear(); + Ops << FP32Ty << Call->getOperand(2); + auto c_f32 = addSPIRVInst(spv::OpFConvert, Ops); + + Ops.clear(); + Ops << FP32Ty << getOpExtInstImportID() << EInst << a_f32 << b_f32 + << c_f32; + auto fma_f32 = addSPIRVInst(spv::OpExtInst, Ops); + + Ops.clear(); + Ops << FP16Ty << fma_f32; + RID = addSPIRVInst(spv::OpFConvert, Ops); + + } else if (EInst && Call->getCalledFunction()->isDeclaration()) { SPIRVID ExtInstImportID = getOpExtInstImportID(); // diff --git a/test/half-fma-native.cl b/test/half-fma-native.cl new file mode 100644 index 000000000..829046978 --- /dev/null +++ b/test/half-fma-native.cl @@ -0,0 +1,18 @@ +// RUN: clspv %target %s -o %t.spv --use-native-builtins=fma +// RUN: spirv-dis -o %t2.spvasm %t.spv +// RUN: FileCheck %s < %t2.spvasm +// RUN: spirv-val --target-env vulkan1.0 %t.spv + +// CHECK: OpCapability Float16 +// CHECK: [[half:%[^ ]+]] = OpTypeFloat 16 +// CHECK: [[a:%[^ ]+]] = OpLoad [[half]] +// CHECK: [[b:%[^ ]+]] = OpLoad [[half]] +// CHECK: [[c:%[^ ]+]] = OpLoad [[half]] +// CHECK: [[fma:%[^ ]+]] = OpExtInst [[half]] {{.*}} Fma [[a]] [[b]] [[c]] +// CHECK: OpStore {{.*}} [[fma]] +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void foo(global half *dst, global half *srcA, global half *srcB, global half *srcC) { + int gid = get_global_id(0); + dst[gid] = srcA[gid] * srcB[gid] + srcC[gid]; +} diff --git a/test/half-fma.cl b/test/half-fma.cl index 7d5acaffc..4c44f6780 100644 --- a/test/half-fma.cl +++ b/test/half-fma.cl @@ -5,10 +5,15 @@ // CHECK: OpCapability Float16 // CHECK: [[half:%[^ ]+]] = OpTypeFloat 16 +// CHECK: [[float:%[^ ]+]] = OpTypeFloat 32 // CHECK: [[a:%[^ ]+]] = OpLoad [[half]] // CHECK: [[b:%[^ ]+]] = OpLoad [[half]] // CHECK: [[c:%[^ ]+]] = OpLoad [[half]] -// CHECK: [[fma:%[^ ]+]] = OpExtInst [[half]] {{.*}} Fma [[a]] [[b]] [[c]] +// CHECK: [[a32:%[^ ]+]] = OpFConvert [[float]] [[a]] +// CHECK: [[b32:%[^ ]+]] = OpFConvert [[float]] [[b]] +// CHECK: [[c32:%[^ ]+]] = OpFConvert [[float]] [[c]] +// CHECK: [[fma32:%[^ ]+]] = OpExtInst [[float]] {{.*}} Fma [[a32]] [[b32]] [[c32]] +// CHECK: [[fma:%[^ ]+]] = OpFConvert [[half]] [[fma32]] // CHECK: OpStore {{.*}} [[fma]] #pragma OPENCL EXTENSION cl_khr_fp16 : enable