Skip to content

Commit 40783c1

Browse files
committed
GPU support, BitPacking extension
1 parent 7c78a98 commit 40783c1

File tree

9 files changed

+49
-35
lines changed

9 files changed

+49
-35
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
name = "Microfloats"
22
uuid = "31c70f10-a750-4521-b13c-797315ae2933"
33
authors = ["Anton Oresten <[email protected]> and contributors"]
4-
version = "0.0.6"
4+
version = "0.0.7"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88

9+
[weakdeps]
10+
BitPacking = "b58c8408-13c4-4787-8733-ac52107ded21"
11+
12+
[extensions]
13+
BitPackingExt = "BitPacking"
14+
915
[compat]
1016
Random = "1"
1117
julia = "1.10"

ext/BitPackingExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module BitPackingExt
2+
3+
using Microfloats
4+
using BitPacking
5+
6+
# Tell BitPacking how many bits a microfloat type occupies by delegating to
# Microfloats.n_bits. The type parameter T must be bound with `where`; without
# it, the body's reference to T raises UndefVarError on every call.
BitPacking.bitwidth(::Type{T}) where {T<:Microfloat} = Microfloats.n_bits(T)
7+
8+
end

src/MX/MX.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ abstract type MX <: Variant end
22

33
const MX_Microfloat{S,E,M} = Microfloat{S,E,M,MX}
44

5-
const MX_E5M2 = IEEEMicrofloat{1,5,2}
5+
const MX_E5M2 = MX_Microfloat{1,5,2} # technically IEEE 754 compliant
66
const MX_E4M3 = MX_Microfloat{1,4,3}
77
const MX_E3M2 = MX_Microfloat{1,3,2}
88
const MX_E2M3 = MX_Microfloat{1,2,3}
@@ -15,7 +15,7 @@ const NO_NAN_OR_INF = Union{MX_E3M2, MX_E2M3, MX_E2M1}
1515

1616
Base.isinf(::NO_INF) = false
1717
Base.isnan(::NO_NAN) = false
18-
nan(::Type{T}) where T<:NO_NAN = throw(DomainError(T, "$T has no NaN values"))
18+
# For types in NO_NAN, return zero(T) as the stand-in value rather than raising an error.
nan(::Type{T}) where T<:NO_NAN = zero(T)
1919

2020
Base.floatmax(::Type{T}) where T<:NO_NAN_OR_INF = reinterpret(T, exponent_mask(T) | mantissa_mask(T))
2121

@@ -33,8 +33,8 @@ nan(::Type{MX_E8M0}) = reinterpret(MX_E8M0, 0xff)
3333
# Float32 conversion for MX variants:
3434
# - exp=all-ones is "normal" except for the MX NaN sentinel(s)
3535
# - otherwise identical mapping as IEEE
36-
function _float32(x::T) where {T<:MX_Microfloat}
37-
T isa MX_E8M0 && reinterpret(UInt8, x) == 0xff && return NaN32
36+
function _float32(x::T) where {T<:Union{MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0}}
37+
T <: MX_E8M0 && reinterpret(UInt8, x) === 0xff && return NaN32
3838

3939
sgn = UInt32(right_aligned_sign(x))
4040
exp = UInt32(right_aligned_exponent(x))
@@ -73,7 +73,7 @@ function _float32(x::T) where {T<:MX_Microfloat}
7373
end
7474

7575
# Saturating to_microfloat tables for MX (no Infs; overflow -> ±floatmax)
76-
function create_base_shifttable(::Type{T}) where {T<:MX_Microfloat}
76+
function create_base_shifttable(::Type{T}) where {T<:Union{MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0}}
7777
basetable = Vector{T}(undef, 512)
7878
shifttable = Vector{UInt8}(undef, 512)
7979

@@ -89,7 +89,7 @@ function create_base_shifttable(::Type{T}) where {T<:MX_Microfloat}
8989
shifttable[i|0x000+1] = -e + e_shift_subnorm
9090
shifttable[i|0x100+1] = -e + e_shift_subnorm
9191
elseif e < e_overflow_mx
92-
basebits = (e + Int(bias(T))) << exponent_offset(T)
92+
basebits = (e + Int(exponent_bias(T))) << exponent_offset(T)
9393
basetable[i|0x000+1] = reinterpret(T, UInt8(basebits))
9494
basetable[i|0x100+1] = reinterpret(T, UInt8(basebits | Int(sign_mask(T))))
9595
shifttable[i|0x000+1] = n_mantissa_bits(Float32) - n_mantissa_bits(T)
@@ -106,7 +106,7 @@ function create_base_shifttable(::Type{T}) where {T<:MX_Microfloat}
106106
shifttable[i|0x100+1] = n_mantissa_bits(Float32) - n_mantissa_bits(T)
107107
end
108108
end
109-
return reinterpret(UInt8, basetable), shifttable
109+
return (reinterpret(UInt8, basetable)...,), (shifttable...,)
110110
end
111111

112112
# Saturating bounds for MX: use finite extrema

src/conversion/conversion.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
bias(::Type{T}) where T<:Microfloat = UInt32(2^(n_exponent_bits(T) - 1) - 1)
2-
bias_difference(::Type{T}) where T<:Microfloat = UInt32(127 - bias(T))
1+
bias_difference(::Type{T}) where T<:Microfloat = UInt32(exponent_bias(Float32) - exponent_bias(T))
32

43
include("to_microfloat.jl")
54
include("from_microfloat.jl")

src/conversion/to_microfloat.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
e_subnormal(T) = 1 - bias(T) - n_mantissa_bits(T)
2-
e_normal(T) = 1 - bias(T)
3-
e_overflow(T) = (2^n_exponent_bits(T) - 2) - bias(T) + 1
1+
e_subnormal(T) = 1 - exponent_bias(T) - n_mantissa_bits(T)
2+
e_normal(T) = 1 - exponent_bias(T)
3+
e_overflow(T) = (2^n_exponent_bits(T) - 2) - exponent_bias(T) + 1
44

55
function create_base_shifttable(::Type{T}) where {T<:Microfloat}
66

@@ -20,7 +20,7 @@ function create_base_shifttable(::Type{T}) where {T<:Microfloat}
2020
shifttable[i|0x000+1] = -e+e_shift_subnorm
2121
shifttable[i|0x100+1] = -e+e_shift_subnorm
2222
elseif e < e_overflow(T) # Normal numbers just lose precision
23-
basebits = (e + Int(bias(T))) << exponent_offset(T)
23+
basebits = (e + Int(exponent_bias(T))) << exponent_offset(T)
2424
basetable[i|0x000+1] = reinterpret(T, UInt8(basebits))
2525
basetable[i|0x100+1] = reinterpret(T, UInt8(basebits | Int(sign_mask(T))))
2626
shifttable[i|0x000+1] = n_mantissa_bits(Float32)-n_mantissa_bits(T)
@@ -49,7 +49,6 @@ end
4949
basetable, shifttable = create_base_shifttable(T)
5050

5151
quote
52-
isnan(x) && return nan(T) # TODO retain the significant bits for NaN?
5352
f = reinterpret(UInt32, x)
5453

5554
# exponent+sign index into 512-entry tables (9 bits), 1-based
@@ -75,6 +74,6 @@ end
7574
h = h + (UInt8(1) << mantissa_offset(T))
7675
end
7776
end
78-
return reinterpret(T, h)
77+
return ifelse(isnan(x), nan(T), reinterpret(T, h)) # TODO retain the significant bits for NaN?
7978
end
8079
end

src/float-bits.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@ uint(::Type{T}) where T<:Unsigned = T
77
as_uint(x::T) where T<:AbstractFloat = reinterpret(uint(T), x)
88
bit_ones(N, T=UInt8) = (one(uint(T)) << N) - one(uint(T))
99

10-
n_total_bits(::Type{T}) where T<:AbstractFloat = sizeof(T) * 8
11-
n_utilized_bits(::Type{T}) where T<:AbstractFloat = n_sign_bits(T) + n_exponent_bits(T) + n_mantissa_bits(T)
12-
n_rpad_bits(::Type{T}) where T<:AbstractFloat = 0
10+
n_bits(::Type{T}) where T<:AbstractFloat = n_sign_bits(T) + n_exponent_bits(T) + n_mantissa_bits(T)
1311

14-
mantissa_offset(::Type{T}) where T<:AbstractFloat = n_rpad_bits(T)
12+
mantissa_offset(::Type{T}) where T<:AbstractFloat = 0
1513
exponent_offset(::Type{T}) where T<:AbstractFloat = n_mantissa_bits(T) + mantissa_offset(T)
1614
sign_offset(::Type{T}) where T<:AbstractFloat = n_exponent_bits(T) + exponent_offset(T)
1715

@@ -27,6 +25,8 @@ right_aligned_sign(x::T) where T<:AbstractFloat = only_sign(x) >> sign_offset(T)
2725
right_aligned_exponent(x::T) where T<:AbstractFloat = only_exponent(x) >> exponent_offset(T)
2826
right_aligned_mantissa(x::T) where T<:AbstractFloat = only_mantissa(x) >> mantissa_offset(T)
2927

28+
# Exponent bias 2^(E-1) - 1 for a float type with E exponent bits, returned as a UInt32.
exponent_bias(::Type{T}) where T<:AbstractFloat = UInt32(2^(n_exponent_bits(T) - 1) - 1)
29+
3030
# right_aligned_sign_mask(::Type{T}) where T<:AbstractFloat = bit_ones(n_sign_bits(T), T)
3131
right_aligned_exponent_mask(::Type{T}) where T<:AbstractFloat = bit_ones(n_exponent_bits(T), T)
3232
right_aligned_mantissa_mask(::Type{T}) where T<:AbstractFloat = bit_ones(n_mantissa_bits(T), T)

test/MX/MX_compliance.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
@testset "FP8" begin
99

1010
@testset "E4M3" begin
11-
@test Microfloats.bias(MX_E4M3) == 7
11+
@test Microfloats.exponent_bias(MX_E4M3) == 7
1212

1313
@test isfinite(reinterpret(MX_E4M3, 0b0_1111_000))
1414
@test isfinite(reinterpret(MX_E4M3, 0b1_1111_000))
@@ -38,7 +38,7 @@
3838
end
3939

4040
@testset "E5M2" begin
41-
@test Microfloats.bias(MX_E5M2) == 15
41+
@test Microfloats.exponent_bias(MX_E5M2) == 15
4242

4343
@test reinterpret(UInt8, MX_E5M2(Inf)) == 0b0_11111_00
4444
@test reinterpret(UInt8, MX_E5M2(-Inf)) == 0b1_11111_00
@@ -71,7 +71,7 @@
7171
@testset "FP6" begin
7272

7373
@testset "E2M3" begin
74-
@test Microfloats.bias(MX_E2M3) == 1
74+
@test Microfloats.exponent_bias(MX_E2M3) == 1
7575

7676
@test isfinite(reinterpret(MX_E2M3, 0b0_11_000))
7777
@test isfinite(reinterpret(MX_E2M3, 0b1_11_000))
@@ -98,7 +98,7 @@
9898
end
9999

100100
@testset "E3M2" begin
101-
@test Microfloats.bias(MX_E3M2) == 3
101+
@test Microfloats.exponent_bias(MX_E3M2) == 3
102102

103103
@test isfinite(reinterpret(MX_E3M2, 0b0_111_00))
104104
@test isfinite(reinterpret(MX_E3M2, 0b1_111_00))
@@ -129,7 +129,7 @@
129129
@testset "FP4" begin
130130

131131
@testset "E2M1" begin
132-
@test Microfloats.bias(MX_E2M1) == 1
132+
@test Microfloats.exponent_bias(MX_E2M1) == 1
133133

134134
@test isfinite(reinterpret(MX_E2M1, 0b0_11_0))
135135
@test isfinite(reinterpret(MX_E2M1, 0b1_11_0))
@@ -158,7 +158,7 @@
158158

159159
# arithmetic not yet supported for unsigned microfloats
160160
@testset "E8M0" begin
161-
@test Microfloats.bias(MX_E8M0) == 127
161+
@test Microfloats.exponent_bias(MX_E8M0) == 127
162162

163163
#@test floatmax(E8M0) == floatmax(Float32) / 2
164164

test/MX/MX_properties.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,21 +76,23 @@ end
7676
end
7777

7878
@testset "MX: saturation and NaN/Inf mapping from Float32" begin
79-
for T in (MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0)
79+
for T in (MX_E5M2, MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0)
8080
@testset "$T" begin
8181
fmax = floatmax(T)
8282
# +Inf/-Inf map to ±floatmax (unsigned maps both to +floatmax); MX_E5M2 is excluded from this check
83-
@test T(Inf32) == fmax
84-
if Microfloats.n_sign_bits(T) == 0
85-
@test T(-Inf32) == fmax
86-
else
87-
@test T(-Inf32) == -fmax
83+
if !(T <: MX_E5M2)
84+
@test T(Inf32) == fmax
85+
if Microfloats.n_sign_bits(T) == 0
86+
@test T(-Inf32) == fmax
87+
else
88+
@test T(-Inf32) == -fmax
89+
end
8890
end
8991
# NaN maps to a NaN sentinel for E4M3/E5M2/E8M0; all other types map NaN to zero
9092
if T <: Union{MX_E4M3, MX_E5M2, MX_E8M0}
9193
@test isnan(T(NaN32))
9294
else
93-
@test_throws DomainError T(NaN32)
95+
@test iszero(T(NaN32))
9496
end
9597
# Values just beyond floatmax saturate
9698
big = nextfloat(Float32(fmax))
@@ -110,7 +112,7 @@ end
110112
if Microfloats.has_mantissa(T)
111113
u = UInt8(1) << Microfloats.mantissa_offset(T)
112114
x = reinterpret(T, u)
113-
expected = Float32(2.0)^(1 - Microfloats.bias(T) - Microfloats.n_mantissa_bits(T))
115+
expected = Float32(2.0)^(1 - Microfloats.exponent_bias(T) - Microfloats.n_mantissa_bits(T))
114116
@test Float32(x) == expected
115117
end
116118
end

test/Microfloat.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272

7373
@testset "IEEE microfloats: subnormals and rounding" begin
7474
@testset for T in TYPES
75-
bias = Microfloats.bias(T)
75+
bias = Microfloats.exponent_bias(T)
7676
M = Microfloats.n_mantissa_bits(T)
7777
mo = Microfloats.mantissa_offset(T)
7878
# Encoding for the minimum positive subnormal (mantissa LSB only)

0 commit comments

Comments
 (0)