Skip to content

Commit

Permalink
add an always compliant version of fma on fp16 (#1374)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjodinchr authored Jun 27, 2024
1 parent 7f7d055 commit 720be7f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 2 deletions.
37 changes: 36 additions & 1 deletion lib/SPIRVProducerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4516,7 +4516,42 @@ SPIRVID SPIRVProducerPassImpl::GenerateInstructionFromCall(CallInst *Call) {
glsl::ExtInst EInst = Builtins::getDirectOrIndirectExtInstEnum(func_info);

// Do not replace functions with implementations.
if (EInst && Call->getCalledFunction()->isDeclaration()) {
if (EInst && EInst == glsl::ExtInst::ExtInstFma &&
!clspv::Option::UnsafeMath() &&
clspv::Option::UseNativeBuiltins().count(
clspv::Builtins::BuiltinType::kFma) == 0 &&
Call->getType()->getScalarType()->isHalfTy()) {
Type *FP16Ty = Call->getType();

Type *FP32Ty = Type::getFloatTy(Context);
if (FP16Ty->isVectorTy()) {
FP32Ty = FixedVectorType::get(
FP32Ty, dyn_cast<FixedVectorType>(FP16Ty)->getNumElements());
}

SPIRVOperandVec Ops;
Ops.clear();
Ops << FP32Ty << Call->getOperand(0);
auto a_f32 = addSPIRVInst(spv::OpFConvert, Ops);

Ops.clear();
Ops << FP32Ty << Call->getOperand(1);
auto b_f32 = addSPIRVInst(spv::OpFConvert, Ops);

Ops.clear();
Ops << FP32Ty << Call->getOperand(2);
auto c_f32 = addSPIRVInst(spv::OpFConvert, Ops);

Ops.clear();
Ops << FP32Ty << getOpExtInstImportID() << EInst << a_f32 << b_f32
<< c_f32;
auto fma_f32 = addSPIRVInst(spv::OpExtInst, Ops);

Ops.clear();
Ops << FP16Ty << fma_f32;
RID = addSPIRVInst(spv::OpFConvert, Ops);

} else if (EInst && Call->getCalledFunction()->isDeclaration()) {
SPIRVID ExtInstImportID = getOpExtInstImportID();

//
Expand Down
18 changes: 18 additions & 0 deletions test/half-fma-native.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: clspv %target %s -o %t.spv --use-native-builtins=fma
// RUN: spirv-dis -o %t2.spvasm %t.spv
// RUN: FileCheck %s < %t2.spvasm
// RUN: spirv-val --target-env vulkan1.0 %t.spv

// CHECK: OpCapability Float16
// CHECK: [[half:%[^ ]+]] = OpTypeFloat 16
// CHECK: [[a:%[^ ]+]] = OpLoad [[half]]
// CHECK: [[b:%[^ ]+]] = OpLoad [[half]]
// CHECK: [[c:%[^ ]+]] = OpLoad [[half]]
// CHECK: [[fma:%[^ ]+]] = OpExtInst [[half]] {{.*}} Fma [[a]] [[b]] [[c]]
// CHECK: OpStore {{.*}} [[fma]]
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Test kernel: for each work-item, writes srcA[gid] * srcB[gid] + srcC[gid]
// to dst[gid], with every operand and the result in half precision.
//
// This file is compiled with --use-native-builtins=fma, and the FileCheck
// directives above expect the three half loads to feed a single half-typed
// Fma extended instruction whose result is stored directly.
// NOTE(review): the multiply-add is presumably contracted into fma by the
// compiler before SPIR-V generation - the kernel never calls fma() itself.
kernel void foo(global half *dst, global half *srcA, global half *srcB, global half *srcC) {
// Index of this work-item in dimension 0; the same index is used for all
// four buffers.
int gid = get_global_id(0);
// Keep this as a plain multiply followed by an add: the expected output
// above matches one Fma taking the a, b, c loads in this order.
dst[gid] = srcA[gid] * srcB[gid] + srcC[gid];
}
7 changes: 6 additions & 1 deletion test/half-fma.cl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@

// CHECK: OpCapability Float16
// CHECK: [[half:%[^ ]+]] = OpTypeFloat 16
// CHECK: [[float:%[^ ]+]] = OpTypeFloat 32
// CHECK: [[a:%[^ ]+]] = OpLoad [[half]]
// CHECK: [[b:%[^ ]+]] = OpLoad [[half]]
// CHECK: [[c:%[^ ]+]] = OpLoad [[half]]
// CHECK: [[fma:%[^ ]+]] = OpExtInst [[half]] {{.*}} Fma [[a]] [[b]] [[c]]
// CHECK: [[a32:%[^ ]+]] = OpFConvert [[float]] [[a]]
// CHECK: [[b32:%[^ ]+]] = OpFConvert [[float]] [[b]]
// CHECK: [[c32:%[^ ]+]] = OpFConvert [[float]] [[c]]
// CHECK: [[fma32:%[^ ]+]] = OpExtInst [[float]] {{.*}} Fma [[a32]] [[b32]] [[c32]]
// CHECK: [[fma:%[^ ]+]] = OpFConvert [[half]] [[fma32]]
// CHECK: OpStore {{.*}} [[fma]]
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

Expand Down

0 comments on commit 720be7f

Please sign in to comment.