Skip to content

Runtime intrinsics: fix fpext and fptrunc behaviour on Float16/BFloat16 #57160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2454,6 +2454,9 @@ const _SPECIAL_BUILTINS = Any[
Core._apply_iterate,
]

# Types compatible with fpext/fptrunc
const CORE_FLOAT_TYPES = Union{Core.BFloat16, Float16, Float32, Float64}

function isdefined_effects(𝕃::AbstractLattice, argtypes::Vector{Any})
# consistent if the first arg is immutable
na = length(argtypes)
Expand Down Expand Up @@ -2867,6 +2870,17 @@ function intrinsic_exct(𝕃::AbstractLattice, f::IntrinsicFunction, argtypes::V
if !(isprimitivetype(ty) && isprimitivetype(xty))
return ErrorException
end

# fpext and fptrunc have further restrictions on the allowed types.
if f === Intrinsics.fpext &&
!(ty <: CORE_FLOAT_TYPES && xty <: CORE_FLOAT_TYPES && Core.sizeof(ty) > Core.sizeof(xty))
return ErrorException
end
if f === Intrinsics.fptrunc &&
!(ty <: CORE_FLOAT_TYPES && xty <: CORE_FLOAT_TYPES && Core.sizeof(ty) < Core.sizeof(xty))
return ErrorException
end

return Union{}
end

Expand Down
11 changes: 11 additions & 0 deletions Compiler/test/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1384,3 +1384,14 @@ end |> Compiler.is_nothrow
@test Base.infer_effects() do
@ccall unsafecall()::Cvoid
end == Compiler.EFFECTS_UNKNOWN

# fpext
@test Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Float16])
@test Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float64}, Float16])
@test Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float64}, Float32])
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float16}, Float16])
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float16}, Float32])
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Float32])
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Float64])
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Int32}, Float16])
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Int16])
23 changes: 15 additions & 8 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,16 +672,23 @@ static jl_cgval_t generic_cast(
uint32_t nb = jl_datatype_size(jlto);
Type *to = bitstype_to_llvm((jl_value_t*)jlto, ctx.builder.getContext(), true);
Type *vt = bitstype_to_llvm(v.typ, ctx.builder.getContext(), true);
if (toint)
to = INTT(to, DL);
else
to = FLOATT(to);
if (fromint)
vt = INTT(vt, DL);
else
vt = FLOATT(vt);

// fptrunc fpext depend on the specific floating point format to work
// correctly, and so do not pun their argument types.
if (!(f == fpext || f == fptrunc)) {
if (toint)
to = INTT(to, DL);
else
to = FLOATT(to);
if (fromint)
vt = INTT(vt, DL);
else
vt = FLOATT(vt);
}

if (!to || !vt)
return emit_runtime_call(ctx, f, argv, 2);

Value *from = emit_unbox(ctx, vt, v, v.typ);
if (!CastInst::castIsValid(Op, from, to))
return emit_runtime_call(ctx, f, argv, 2);
Expand Down
145 changes: 91 additions & 54 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,11 @@ static inline uint16_t float_to_half(float param) JL_NOTSAFEPOINT
uint32_t f;
memcpy(&f, &param, sizeof(float));
if (isnan(param)) {
uint32_t t = 0x8000 ^ (0x8000 & ((uint16_t)(f >> 0x10)));
return t ^ ((uint16_t)(f >> 0xd));
// Match the behaviour of arm64's fcvt or x86's vcvtps2ph by quieting
// all NaNs (avoids creating infinities), preserving the sign, and using
// the upper bits of the payload.
// sign exp quiet payload
return (f>>16 & 0x8000) | 0x7c00 | 0x0200 | (f>>13 & 0x03ff);
}
int i = ((f & ~0x007fffff) >> 23);
uint8_t sh = shifttable[i];
Expand Down Expand Up @@ -761,33 +764,25 @@ static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_N
OP(ty, (c_type*)pr, a); \
}

#define un_fintrinsic_half(OP, name) \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = half_to_float(a); \
if (osize == 16) { \
float R; \
OP(ty, &R, A); \
*(uint16_t*)pr = float_to_half(R); \
} else { \
OP(ty, (uint16_t*)pr, A); \
} \
}
#define un_fintrinsic_half(OP, name) \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) \
JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t *)pa; \
float R, A = half_to_float(a); \
OP(ty, &R, A); \
*(uint16_t *)pr = float_to_half(R); \
}

#define un_fintrinsic_bfloat(OP, name) \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = bfloat_to_float(a); \
if (osize == 16) { \
float R; \
OP(ty, &R, A); \
*(uint16_t*)pr = float_to_bfloat(R); \
} else { \
OP(ty, (uint16_t*)pr, A); \
} \
}
#define un_fintrinsic_bfloat(OP, name) \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) \
JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t *)pa; \
float R, A = bfloat_to_float(a); \
OP(ty, &R, A); \
*(uint16_t *)pr = float_to_bfloat(R); \
}

// float or integer inputs
// OP::Function macro(inputa, inputb)
Expand Down Expand Up @@ -1629,32 +1624,74 @@ cvt_iintrinsic(LLVMUItoFP, uitofp)
cvt_iintrinsic(LLVMFPtoSI, fptosi)
cvt_iintrinsic(LLVMFPtoUI, fptoui)

#define fptrunc(tr, pr, a) \
if (!(osize < 8 * sizeof(a))) \
jl_error("fptrunc: output bitsize must be < input bitsize"); \
else if (osize == 16) { \
if ((jl_datatype_t*)tr == jl_float16_type) \
*(uint16_t*)pr = float_to_half(a); \
else /*if ((jl_datatype_t*)tr == jl_bfloat16_type)*/ \
*(uint16_t*)pr = float_to_bfloat(a); \
} \
else if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
*(double*)pr = a; \
else \
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
#define fpext(tr, pr, a) \
if (!(osize >= 8 * sizeof(a))) \
jl_error("fpext: output bitsize must be >= input bitsize"); \
if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
*(double*)pr = a; \
else \
jl_error("fpext: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
un_fintrinsic_withtype(fptrunc,fptrunc)
un_fintrinsic_withtype(fpext,fpext)
#define fintrinsic_read_float16(p) half_to_float(*(uint16_t *)p)
#define fintrinsic_read_bfloat16(p) bfloat_to_float(*(uint16_t *)p)
#define fintrinsic_read_float32(p) *(float *)p
#define fintrinsic_read_float64(p) *(double *)p

#define fintrinsic_write_float16(p, x) *(uint16_t *)p = float_to_half(x)
#define fintrinsic_write_bfloat16(p, x) *(uint16_t *)p = float_to_bfloat(x)
#define fintrinsic_write_float32(p, x) *(float *)p = x
#define fintrinsic_write_float64(p, x) *(double *)p = x

/*
* aty: Type of value argument (input)
* pa: Pointer to value argument data
* ty: Type argument (output)
* pr: Pointer to result data
*/

static inline void fptrunc(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void *pr)
{
unsigned isize = jl_datatype_size(aty), osize = jl_datatype_size(ty);
if (!(osize < isize)) {
jl_error("fptrunc: output bitsize must be < input bitsize");
return;
}

#define fptrunc_convert(in, out) \
else if (aty == jl_##in##_type && ty == jl_##out##_type) \
fintrinsic_write_##out(pr, fintrinsic_read_##in(pa))

if (0)
;
fptrunc_convert(float32, float16);
fptrunc_convert(float64, float16);
fptrunc_convert(float32, bfloat16);
fptrunc_convert(float64, bfloat16);
fptrunc_convert(float64, float32);
else
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
Comment on lines +1663 to +1664
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are changing the permitted combination of behaviors, you will also need to update Compiler/src/tfuncs.jl and src/intrinsics.cpp to match these new rules, so that they all agree on exactly which errors are thrown and for what cases they can occur

#undef fptrunc_convert
}

static inline void fpext(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void *pr)
{
unsigned isize = jl_datatype_size(aty), osize = jl_datatype_size(ty);
if (!(osize > isize)) {
jl_error("fpext: output bitsize must be > input bitsize");
return;
}

#define fpext_convert(in, out) \
else if (aty == jl_##in##_type && ty == jl_##out##_type) \
fintrinsic_write_##out(pr, fintrinsic_read_##in(pa))

if (0)
;
fpext_convert(float16, float32);
fpext_convert(float16, float64);
fpext_convert(bfloat16, float32);
fpext_convert(bfloat16, float64);
fpext_convert(float32, float64);
else
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
#undef fpext_convert
}

cvt_iintrinsic(fptrunc, fptrunc)
cvt_iintrinsic(fpext, fpext)


// checked arithmetic
/**
Expand Down
Loading