Skip to content

Commit 3abb5f6

Browse files
committed
add NVFP4, remove rpad by aligning bits to right
1 parent 87524be commit 3abb5f6

File tree

10 files changed

+103
-98
lines changed

10 files changed

+103
-98
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Microfloats"
22
uuid = "31c70f10-a750-4521-b13c-797315ae2933"
33
authors = ["Anton Oresten <[email protected]> and contributors"]
4-
version = "0.0.3"
4+
version = "0.0.4"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ const E2M1 = Microfloat(1, 2, 1, :MX)
3636
const E8M0 = Microfloat(0, 8, 0, :MX)
3737
```
3838

39+
For `INT8`, see `FixedPointNumbers.Q1f6`.
40+
3941
## Installation
4042

4143
```julia
@@ -49,3 +51,4 @@ Pkg.add("Microfloats")
4951
- [MicroFloatingPoints.jl](https://github.com/goualard-f/MicroFloatingPoints.jl)
5052
- [DLFP8Types.jl](https://github.com/chengchingwen/DLFP8Types.jl)
5153
- [Float8s.jl](https://github.com/JuliaMath/Float8s.jl)
54+
- [FixedPointNumbers.jl](https://github.com/JuliaMath/FixedPointNumbers.jl)

src/Microfloat.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
primitive type Microfloat{S,E,M,V} <: AbstractFloat 8 end
1+
abstract type Variant end
2+
abstract type IEEE <: Variant end
23

3-
const SignedMicrofloat = Microfloat{1}
4-
const UnsignedMicrofloat = Microfloat{0}
4+
primitive type Microfloat{S,E,M,V} <: AbstractFloat 8 end
55

66
"""
7-
Microfloat(S, E, M, V=:IEEE)
7+
Microfloat(S, E, M, V=IEEE)
88
99
Create a new `Microfloat` type with `S` sign bits, `E` exponent bits, and `M` mantissa bits.
1010
1111
This "type constructor" ensures that the resulting type is legal.
1212
1313
The `V` argument selects the variant and defaults to `IEEE`; pass `MX` to create a Microscaling Format (MX) type.
1414
"""
15-
function Microfloat(S::Int, E::Int, M::Int, V::Symbol=:IEEE)
15+
function Microfloat(S::Int, E::Int, M::Int, V::Type{<:Variant}=IEEE)
1616
S in (0, 1) || throw(ArgumentError("sign bit must be 0 or 1"))
1717
# At least one exponent bit is required, so the message says "positive"
# (the guard rejects E == 0, which "non-negative" would wrongly allow).
E >= 1 || throw(ArgumentError("number of exponent bits must be positive"))
1818
M >= 0 || throw(ArgumentError("number of mantissa bits must be non-negative"))
@@ -47,7 +47,7 @@ Base.floatmin(::Type{T}) where T<:Microfloat = n_exponent_bits(T) > 1 ? reinterp
4747
Base.floatmax(::Type{T}) where T<:Microfloat = reinterpret(T, bit_ones(n_exponent_bits(T) - 1) << (exponent_offset(T) + 1) | mantissa_mask(T))
4848

4949
Base.typemin(::Type{T}) where T<:Microfloat = -inf(T)
50-
Base.typemin(::Type{T}) where T<:UnsignedMicrofloat = zero(T)
50+
Base.typemin(::Type{T}) where T<:Microfloat{0} = zero(T)
5151

5252
Base.typemax(::Type{T}) where T<:Microfloat = inf(T)
5353

src/Microfloats.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ include("float-bits.jl")
44

55
include("Microfloat.jl")
66
export Microfloat
7-
export SignedMicrofloat, UnsignedMicrofloat
7+
export IEEE
88

9-
include("MX/MX.jl")
9+
include("microscaled/microscaled.jl")
10+
export MX, NV
11+
export MX_E5M2, MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0, NV_E2M1
1012

1113
include("conversion/conversion.jl")
1214

src/conversion/to_microfloat.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ function create_base_shifttable(::Type{T}) where {T<:Microfloat}
4343
return reinterpret(UInt8, basetable), shifttable
4444
end
4545

46-
@generated function (::Type{T})(x::Float32) where {S,E,M,T<:Microfloat{S,E,M}}
46+
(::Type{T})(x::Float32) where {S,E,M,T<:Microfloat{S,E,M}} = T{IEEE}(x)
47+
48+
@generated function (::Type{T})(x::Float32) where {S,E,M,V,T<:Microfloat{S,E,M,V}}
4749
basetable, shifttable = create_base_shifttable(T)
4850

4951
quote

src/float-bits.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ bit_ones(N, T=UInt8) = (one(uint(T)) << N) - one(uint(T))
99

1010
# Storage width of `T` in bits (`sizeof` reports bytes, hence the factor of 8).
function n_total_bits(::Type{T}) where T<:AbstractFloat
    return 8 * sizeof(T)
end
1111
n_utilized_bits(::Type{T}) where T<:AbstractFloat = n_sign_bits(T) + n_exponent_bits(T) + n_mantissa_bits(T)
12-
n_padding_bits(::Type{T}) where T<:AbstractFloat = n_total_bits(T) - n_utilized_bits(T)
12+
# Number of padding bits to the right of the mantissa. This generic method
# always returns 0.
function n_rpad_bits(::Type{T}) where T<:AbstractFloat
    return 0
end
1313

14-
mantissa_offset(::Type{T}) where T<:AbstractFloat = n_padding_bits(T)
14+
mantissa_offset(::Type{T}) where T<:AbstractFloat = n_rpad_bits(T)
1515
exponent_offset(::Type{T}) where T<:AbstractFloat = n_mantissa_bits(T) + mantissa_offset(T)
1616
sign_offset(::Type{T}) where T<:AbstractFloat = n_exponent_bits(T) + exponent_offset(T)
1717

src/MX/MX.jl renamed to src/microscaled/microscaled.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
1-
# src/MX/MX.jl
1+
abstract type MX <: Variant end
2+
abstract type NV <: Variant end
3+
const Microscaled = Union{MX, NV}
24

3-
const MXMicrofloat{S,E,M} = Microfloat{S,E,M,:MX}
4-
const SignedMXMicrofloat = MXMicrofloat{1}
5-
const UnsignedMXMicrofloat = MXMicrofloat{0}
5+
const MXMicrofloat{S,E,M} = Microfloat{S,E,M,MX}
6+
const NVMicrofloat{S,E,M} = Microfloat{S,E,M,NV}
7+
const MicroscaledMicrofloat{S,E,M} = Microfloat{S,E,M,<:Microscaled}
68

79
const MX_E5M2 = MXMicrofloat{1,5,2}
810
const MX_E4M3 = MXMicrofloat{1,4,3}
911
const MX_E3M2 = MXMicrofloat{1,3,2}
1012
const MX_E2M3 = MXMicrofloat{1,2,3}
1113
const MX_E2M1 = MXMicrofloat{1,2,1}
1214
const MX_E8M0 = MXMicrofloat{0,8,0}
15+
const NV_E2M1 = NVMicrofloat{1,2,1}
1316

14-
const MX_NO_INF = Union{MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0}
15-
const MX_NO_NAN = Union{MX_E3M2, MX_E2M3, MX_E2M1}
16-
const MX_NO_NAN_OR_INF = Union{MX_E3M2, MX_E2M3, MX_E2M1}
17+
const NO_INF = Union{MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0, NV_E2M1}
18+
const NO_NAN = Union{MX_E3M2, MX_E2M3, MX_E2M1, NV_E2M1}
19+
const NO_NAN_OR_INF = Union{MX_E3M2, MX_E2M3, MX_E2M1, NV_E2M1}
1720

18-
Base.isinf(::MX_NO_INF) = false
19-
Base.isnan(::MX_NO_NAN) = false
20-
nan(::Type{T}) where T<:MX_NO_NAN = throw(DomainError(T, "$T has no NaN values"))
21+
Base.isinf(::NO_INF) = false
22+
Base.isnan(::NO_NAN) = false
23+
nan(::Type{T}) where T<:NO_NAN = throw(DomainError(T, "$T has no NaN values"))
2124

22-
Base.floatmax(::Type{T}) where T<:MX_NO_NAN_OR_INF = reinterpret(T, exponent_mask(T) | mantissa_mask(T))
25+
Base.floatmax(::Type{T}) where T<:NO_NAN_OR_INF = reinterpret(T, exponent_mask(T) | mantissa_mask(T))
2326

2427
# E4M3 (MX): no Infs; only mantissa == 111 at exp=1111 is NaN
2528
nan(::Type{T}) where T<:MX_E4M3 = reinterpret(T, exponent_mask(T) | mantissa_mask(T))
@@ -35,7 +38,7 @@ nan(::Type{MX_E8M0}) = reinterpret(MX_E8M0, 0xff)
3538
# Float32 conversion for microscaled (MX and NV) variants:
3639
# - exp=all-ones is "normal" except for the MX NaN sentinel(s)
3740
# - otherwise identical mapping as IEEE
38-
function _float32(x::T) where {T<:MXMicrofloat}
41+
function _float32(x::T) where {T<:MicroscaledMicrofloat}
3942
# E8M0 encodes NaN as 0xff. `T` is a type, so it has to be compared with `<:`;
# `T isa MX_E8M0` asks whether the type object is an *instance* of MX_E8M0,
# which is never true, leaving this early return unreachable.
T <: MX_E8M0 && reinterpret(UInt8, x) == 0xff && return NaN32
4043

4144
sgn = UInt32(right_aligned_sign(x))
@@ -75,7 +78,7 @@ function _float32(x::T) where {T<:MXMicrofloat}
7578
end
7679

7780
# Saturating to_microfloat tables for microscaled (MX and NV) types (no Infs; overflow -> ±floatmax)
78-
function create_base_shifttable(::Type{T}) where {T<:MXMicrofloat}
81+
function create_base_shifttable(::Type{T}) where {T<:MicroscaledMicrofloat}
7982
basetable = Vector{T}(undef, 512)
8083
shifttable = Vector{UInt8}(undef, 512)
8184

@@ -112,5 +115,5 @@ function create_base_shifttable(::Type{T}) where {T<:MXMicrofloat}
112115
end
113116

114117
# Saturating bounds for microscaled (MX and NV) types: use finite extrema
115-
Base.typemax(::Type{T}) where {S,E,M,T<:MXMicrofloat{S,E,M}} = floatmax(T)
116-
Base.typemin(::Type{T}) where {S,E,M,T<:MXMicrofloat{S,E,M}} = ifelse(n_sign_bits(T) == 0, zero(T), -floatmax(T))
118+
# Upper bound for microscaled types: the largest finite value, `floatmax(T)`.
function Base.typemax(::Type{T}) where {S,E,M,T<:MicroscaledMicrofloat{S,E,M}}
    return floatmax(T)
end
119+
# Lower bound for microscaled types: zero for types without a sign bit,
# otherwise the negated largest finite value.
#
# A ternary is used rather than `ifelse` because `ifelse` is an ordinary
# function and evaluates both value arguments; with the ternary,
# `-floatmax(T)` is only computed for signed types.
function Base.typemin(::Type{T}) where {S,E,M,T<:MicroscaledMicrofloat{S,E,M}}
    return n_sign_bits(T) == 0 ? zero(T) : -floatmax(T)
end

test/MX/MX_compliance.jl

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
@testset "FP8" begin
99

1010
@testset "E4M3" begin
11-
E4M3 = Microfloat(1, 4, 3, :MX)
11+
E4M3 = Microfloat(1, 4, 3, MX)
1212

1313
@test Microfloats.bias(E4M3) == 7
1414

@@ -74,90 +74,90 @@
7474
@testset "FP6" begin
7575

7676
@testset "E2M3" begin
77-
E2M3 = Microfloat(1, 2, 3, :MX)
77+
E2M3 = Microfloat(1, 2, 3, MX)
7878

7979
@test Microfloats.bias(E2M3) == 1
8080

81-
@test isfinite(reinterpret(E2M3, 0b0_11_000_00))
82-
@test isfinite(reinterpret(E2M3, 0b1_11_000_00))
81+
@test isfinite(reinterpret(E2M3, 0b0_11_000))
82+
@test isfinite(reinterpret(E2M3, 0b1_11_000))
8383

8484
for i in 0b001:0b111
85-
@test isfinite(reinterpret(E2M3, 0b0_11_000_00 | i << 2))
86-
@test isfinite(reinterpret(E2M3, 0b1_11_000_00 | i << 2))
85+
# The mantissa sits in the lowest bits (no right padding), so `i` is OR-ed in
# unshifted; shifting by 2 would spill into the exponent field.
@test isfinite(reinterpret(E2M3, 0b0_11_000 | i))
86+
# The mantissa sits in the lowest bits (no right padding), so `i` is OR-ed in
# unshifted; shifting by 2 would spill into the exponent field.
@test isfinite(reinterpret(E2M3, 0b1_11_000 | i))
8787
end
8888

89-
@test iszero(reinterpret(E2M3, 0b0_00_000_00))
90-
@test iszero(reinterpret(E2M3, 0b1_00_000_00))
89+
@test iszero(reinterpret(E2M3, 0b0_00_000))
90+
@test iszero(reinterpret(E2M3, 0b1_00_000))
9191

92-
@test reinterpret(E2M3, 0b0_11_111_00) == 2^2 * 1.875
93-
@test reinterpret(E2M3, 0b1_11_111_00) == -2^2 * 1.875
92+
@test reinterpret(E2M3, 0b0_11_111) == 2^2 * 1.875
93+
@test reinterpret(E2M3, 0b1_11_111) == -2^2 * 1.875
9494

95-
@test reinterpret(E2M3, 0b0_01_000_00) == 2^0 * 1.0
96-
@test reinterpret(E2M3, 0b1_01_000_00) == -2^0 * 1.0
95+
@test reinterpret(E2M3, 0b0_01_000) == 2^0 * 1.0
96+
@test reinterpret(E2M3, 0b1_01_000) == -2^0 * 1.0
9797

98-
@test reinterpret(E2M3, 0b0_00_111_00) == 2^0 * 0.875
99-
@test reinterpret(E2M3, 0b1_00_111_00) == -2^0 * 0.875
98+
@test reinterpret(E2M3, 0b0_00_111) == 2^0 * 0.875
99+
@test reinterpret(E2M3, 0b1_00_111) == -2^0 * 0.875
100100

101-
@test reinterpret(E2M3, 0b0_00_001_00) == 2^0 * 0.125
102-
@test reinterpret(E2M3, 0b1_00_001_00) == -2^0 * 0.125
101+
@test reinterpret(E2M3, 0b0_00_001) == 2^0 * 0.125
102+
@test reinterpret(E2M3, 0b1_00_001) == -2^0 * 0.125
103103

104104
end
105105

106106
@testset "E3M2" begin
107-
E3M2 = Microfloat(1, 3, 2, :MX)
107+
E3M2 = Microfloat(1, 3, 2, MX)
108108

109109
@test Microfloats.bias(E3M2) == 3
110110

111-
@test isfinite(reinterpret(E3M2, 0b0_111_00_00))
112-
@test isfinite(reinterpret(E3M2, 0b1_111_00_00))
111+
@test isfinite(reinterpret(E3M2, 0b0_111_00))
112+
@test isfinite(reinterpret(E3M2, 0b1_111_00))
113113

114114
for i in 0b01:0b11
115-
@test isfinite(reinterpret(E3M2, 0b0_111_00_00 | i << 2))
116-
@test isfinite(reinterpret(E3M2, 0b1_111_00_00 | i << 2))
115+
# The mantissa sits in the lowest bits (no right padding), so `i` is OR-ed in
# unshifted; shifting by 2 would only hit exponent bits that are already set.
@test isfinite(reinterpret(E3M2, 0b0_111_00 | i))
116+
# The mantissa sits in the lowest bits (no right padding), so `i` is OR-ed in
# unshifted; shifting by 2 would only hit exponent bits that are already set.
@test isfinite(reinterpret(E3M2, 0b1_111_00 | i))
117117
end
118118

119-
@test iszero(reinterpret(E3M2, 0b0_000_00_00))
120-
@test iszero(reinterpret(E3M2, 0b1_000_00_00))
119+
@test iszero(reinterpret(E3M2, 0b0_000_00))
120+
@test iszero(reinterpret(E3M2, 0b1_000_00))
121121

122-
@test reinterpret(E3M2, 0b0_111_11_00) == 2^4 * 1.75
123-
@test reinterpret(E3M2, 0b1_111_11_00) == -2^4 * 1.75
122+
@test reinterpret(E3M2, 0b0_111_11) == 2^4 * 1.75
123+
@test reinterpret(E3M2, 0b1_111_11) == -2^4 * 1.75
124124

125-
@test reinterpret(E3M2, 0b0_001_00_00) == 2^-2 * 1.0
126-
@test reinterpret(E3M2, 0b1_001_00_00) == -2^-2 * 1.0
125+
@test reinterpret(E3M2, 0b0_001_00) == 2^-2 * 1.0
126+
@test reinterpret(E3M2, 0b1_001_00) == -2^-2 * 1.0
127127

128-
@test reinterpret(E3M2, 0b0_000_11_00) == 2^-2 * 0.75
129-
@test reinterpret(E3M2, 0b1_000_11_00) == -2^-2 * 0.75
128+
@test reinterpret(E3M2, 0b0_000_11) == 2^-2 * 0.75
129+
@test reinterpret(E3M2, 0b1_000_11) == -2^-2 * 0.75
130130

131-
@test reinterpret(E3M2, 0b0_000_01_00) == 2^-2 * 0.25
132-
@test reinterpret(E3M2, 0b1_000_01_00) == -2^-2 * 0.25
131+
@test reinterpret(E3M2, 0b0_000_01) == 2^-2 * 0.25
132+
@test reinterpret(E3M2, 0b1_000_01) == -2^-2 * 0.25
133133
end
134134

135135
end
136136

137137
@testset "FP4" begin
138138

139139
@testset "E2M1" begin
140-
E2M1 = Microfloat(1, 2, 1, :MX)
140+
E2M1 = Microfloat(1, 2, 1, MX)
141141

142142
@test Microfloats.bias(E2M1) == 1
143143

144-
@test isfinite(reinterpret(E2M1, 0b0_11_0_0000))
145-
@test isfinite(reinterpret(E2M1, 0b1_11_0_0000))
144+
@test isfinite(reinterpret(E2M1, 0b0_11_0))
145+
@test isfinite(reinterpret(E2M1, 0b1_11_0))
146146

147-
@test isfinite(reinterpret(E2M1, 0b0_11_1_0000))
148-
@test isfinite(reinterpret(E2M1, 0b1_11_1_0000))
147+
@test isfinite(reinterpret(E2M1, 0b0_11_1))
148+
@test isfinite(reinterpret(E2M1, 0b1_11_1))
149149

150-
@test iszero(reinterpret(E2M1, 0b0_00_0_0000))
151-
@test iszero(reinterpret(E2M1, 0b1_00_0_0000))
150+
@test iszero(reinterpret(E2M1, 0b0_00_0))
151+
@test iszero(reinterpret(E2M1, 0b1_00_0))
152152

153-
@test reinterpret(E2M1, 0b0_11_1_0000) == 2^2 * 1.5
154-
@test reinterpret(E2M1, 0b1_11_1_0000) == -2^2 * 1.5
153+
@test reinterpret(E2M1, 0b0_11_1) == 2^2 * 1.5
154+
@test reinterpret(E2M1, 0b1_11_1) == -2^2 * 1.5
155155

156-
@test reinterpret(E2M1, 0b0_01_0_0000) == 2^0 * 1.0
157-
@test reinterpret(E2M1, 0b1_01_0_0000) == -2^0 * 1.0
156+
@test reinterpret(E2M1, 0b0_01_0) == 2^0 * 1.0
157+
@test reinterpret(E2M1, 0b1_01_0) == -2^0 * 1.0
158158

159-
@test reinterpret(E2M1, 0b0_00_1_0000) == 2^0 * 0.5
160-
@test reinterpret(E2M1, 0b1_00_1_0000) == -2^0 * 0.5
159+
@test reinterpret(E2M1, 0b0_00_1) == 2^0 * 0.5
160+
@test reinterpret(E2M1, 0b1_00_1) == -2^0 * 0.5
161161
end
162162

163163
end
@@ -168,7 +168,7 @@
168168

169169
# arithmetic not yet supported for unsigned microfloats
170170
@testset "E8M0" begin
171-
E8M0 = Microfloat(0, 8, 0, :MX)
171+
E8M0 = Microfloat(0, 8, 0, MX)
172172

173173
@test Microfloats.bias(E8M0) == 127
174174

test/MX/MX_properties.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
const E4M3 = Microfloat(1, 4, 3, :MX)
2-
const E5M2 = Microfloat(1, 5, 2, :MX)
3-
const E3M2 = Microfloat(1, 3, 2, :MX)
4-
const E2M3 = Microfloat(1, 2, 3, :MX)
5-
const E2M1 = Microfloat(1, 2, 1, :MX)
6-
const E8M0 = Microfloat(0, 8, 0, :MX)
1+
const E4M3 = Microfloat(1, 4, 3, MX)
2+
const E5M2 = Microfloat(1, 5, 2, MX)
3+
const E3M2 = Microfloat(1, 3, 2, MX)
4+
const E2M3 = Microfloat(1, 2, 3, MX)
5+
const E2M1 = Microfloat(1, 2, 1, MX)
6+
const E8M0 = Microfloat(0, 8, 0, MX)
77

88
@testset "MX: no Infs" begin
99
for T in (E4M3, E3M2, E2M3, E2M1, E8M0)
@@ -27,8 +27,8 @@ end
2727
maxm = UInt8((UInt16(1) << nm) - 1)
2828
for s in (UInt8(0), sm)
2929
for mv in UInt8(0):maxm
30-
m = mv << mo
31-
x = reinterpret(T, s | em | m)
30+
m = (mv << mo) & mm
31+
x = reinterpret(T, (s & sm) | em | m)
3232
if m == mm
3333
@test isnan(x)
3434
else
@@ -45,11 +45,12 @@ end
4545
sm = UInt8(Microfloats.sign_mask(T))
4646
mo = Microfloats.mantissa_offset(T)
4747
nm = Microfloats.n_mantissa_bits(T)
48+
mm = UInt8(Microfloats.mantissa_mask(T))
4849
maxm = UInt8((UInt16(1) << nm) - 1)
4950
for s in (UInt8(0), sm)
5051
for mv in UInt8(0):maxm
51-
m = mv << mo
52-
x = reinterpret(T, s | em | m)
52+
m = (mv << mo) & mm
53+
x = reinterpret(T, (s & sm) | em | m)
5354
@test isfinite(x)
5455
@test !isnan(x)
5556
end
@@ -69,11 +70,10 @@ end
6970
@testset "MX: round-trip via Float32 preserves bits (canonical encodings)" begin
7071
for T in (E4M3, E5M2, E3M2, E2M3, E2M1, E8M0)
7172
@testset "$T" begin
72-
mshift = Microfloats.mantissa_offset(T)
73-
mmask = UInt8(Microfloats.mantissa_mask(T))
73+
used_mask = UInt8(Microfloats.sign_mask(T) | Microfloats.exponent_mask(T) | Microfloats.mantissa_mask(T))
7474
for u in UInt8(0):UInt8(0xff)
7575
# Only test canonical encodings: bits outside the sign/exponent/mantissa fields are zero
76-
(u & ~mmask) != (u & ~mmask & ~(UInt8(1)<<mshift - UInt8(1))) && continue
76+
(u & ~used_mask) != 0x00 && continue
7777
x = reinterpret(T, u)
7878
y = T(Float32(x))
7979
# Round-trip must be bit-identical, so compare with `===` rather than `==`.
# NOTE(review): the operator was missing from this line; `===` is assumed from
# the testset title ("preserves bits") — confirm against the repository.
@test y === x
@@ -170,10 +170,9 @@ end
170170
for u in UInt8(0):UInt8(0xff)
171171
x = reinterpret(T, u)
172172
isnan(x) && continue
173-
# Only include canonical encodings
174-
mshift = Microfloats.mantissa_offset(T)
175-
mmask = UInt8(Microfloats.mantissa_mask(T))
176-
(u & ~mmask) != (u & ~mmask & ~(UInt8(1)<<mshift - UInt8(1))) && continue
173+
# Only include canonical encodings: padding bits outside fields are zero
174+
used_mask = UInt8(Microfloats.sign_mask(T) | Microfloats.exponent_mask(T) | Microfloats.mantissa_mask(T))
175+
(u & ~used_mask) != 0x00 && continue
177176
push!(vals, (u, Float32(x), x))
178177
end
179178
sort!(vals, by = t -> t[2])

0 commit comments

Comments
 (0)