diff --git a/packages/typed-binary/src/describe/index.ts b/packages/typed-binary/src/describe/index.ts index 7c10d99..afbd354 100644 --- a/packages/typed-binary/src/describe/index.ts +++ b/packages/typed-binary/src/describe/index.ts @@ -5,6 +5,7 @@ import { ByteSchema, CharsSchema, DynamicArraySchema, + Float16Schema, Float32Schema, GenericObjectSchema, Int32Schema, @@ -36,6 +37,8 @@ export const i32 = new Int32Schema(); export const u32 = new Uint32Schema(); +export const f16 = new Float16Schema(); + export const f32 = new Float32Schema(); export const chars = (length: T) => new CharsSchema(length); diff --git a/packages/typed-binary/src/io/bufferIOBase.ts b/packages/typed-binary/src/io/bufferIOBase.ts index 9362e1c..86b5815 100644 --- a/packages/typed-binary/src/io/bufferIOBase.ts +++ b/packages/typed-binary/src/io/bufferIOBase.ts @@ -17,7 +17,8 @@ export class BufferIOBase { protected readonly uint8View: Uint8Array; protected readonly helperInt32View: Int32Array; protected readonly helperUint32View: Uint32Array; - protected readonly helperFloatView: Float32Array; + protected readonly helperUint16View: Uint16Array; + protected readonly helperFloat32View: Float32Array; protected readonly helperByteView: Uint8Array; protected readonly switchEndianness: boolean; @@ -43,8 +44,9 @@ export class BufferIOBase { const helperBuffer = new ArrayBuffer(4); this.helperInt32View = new Int32Array(helperBuffer); this.helperUint32View = new Uint32Array(helperBuffer); - this.helperFloatView = new Float32Array(helperBuffer); + this.helperFloat32View = new Float32Array(helperBuffer); this.helperByteView = new Uint8Array(helperBuffer); + this.helperUint16View = new Uint16Array(helperBuffer); } get currentByteOffset() { diff --git a/packages/typed-binary/src/io/bufferReader.ts b/packages/typed-binary/src/io/bufferReader.ts index 3d1fcda..c845592 100644 --- a/packages/typed-binary/src/io/bufferReader.ts +++ b/packages/typed-binary/src/io/bufferReader.ts @@ -1,4 +1,5 @@ import { 
/**
 * Rounds a non-negative number to the nearest integer, resolving exact ties
 * towards the even neighbour (the IEEE 754 default rounding mode).
 */
function roundHalfToEven(value: number): number {
  const lower = Math.floor(value);
  const fraction = value - lower;
  if (fraction > 0.5) return lower + 1;
  if (fraction < 0.5) return lower;
  return lower % 2 === 0 ? lower : lower + 1;
}

/**
 * Encodes a JavaScript number as an IEEE 754 binary16 (half precision) value.
 *
 * Layout of the produced 16 bits: 1 sign bit, 5 exponent bits (bias 15),
 * 10 mantissa bits. Values are rounded to the nearest representable half
 * (ties to even). Magnitudes too large for binary16 become +/-Infinity,
 * magnitudes below the smallest normal value (2^-14) are stored as
 * subnormals, and the sign of zero is preserved.
 *
 * @param value - The number to encode.
 * @returns A freshly allocated one-element `Uint16Array` holding the bits.
 */
export function numberToFloat16(value: number): Uint16Array {
  const result = new Uint16Array(1);

  if (Number.isNaN(value)) {
    // Canonical quiet NaN.
    result[0] = 0x7e00;
    return result;
  }

  // `value < 0` is false for -0, so negative zero is detected separately.
  const signBit = value < 0 || Object.is(value, -0) ? 0x8000 : 0;
  const magnitude = Math.abs(value);

  // 65520 is the midpoint between the largest finite half (65504) and 2^16;
  // from there on (including Infinity) round-to-nearest yields Infinity.
  if (magnitude >= 65520) {
    result[0] = signBit | 0x7c00;
    return result;
  }

  // Zero and subnormals: exponent field is 0 and the mantissa counts
  // multiples of 2^-24. A mantissa that rounds up to 1024 spills into the
  // exponent field, which is exactly the smallest normal value (2^-14).
  if (magnitude < 2 ** -14) {
    result[0] = signBit | roundHalfToEven(magnitude * 2 ** 24);
    return result;
  }

  // Normal numbers. `Math.log2` may be off by one ulp around powers of two,
  // so the floored exponent is verified and corrected if necessary.
  let exponent = Math.floor(Math.log2(magnitude));
  if (2 ** exponent > magnitude) {
    exponent -= 1;
  } else if (2 ** (exponent + 1) <= magnitude) {
    exponent += 1;
  }

  // Scaling by a power of two is exact, so rounding sees the true value.
  const mantissaBits = roundHalfToEven((magnitude / 2 ** exponent - 1) * 1024);

  // Addition (not OR) lets a mantissa of 1024 carry into the exponent field.
  result[0] = signBit | (((exponent + 15) << 10) + mantissaBits);
  return result;
}

/**
 * Decodes an IEEE 754 binary16 (half precision) value into a JavaScript number.
 *
 * @param uint16Array - Array whose first element holds the 16 raw bits.
 * @returns The decoded value; handles signed zero, subnormals, infinities
 *   and NaN. An empty array decodes as `0`.
 */
export function float16ToNumber(uint16Array: Uint16Array): number {
  const bits = uint16Array[0] ?? 0;
  const sign = (bits & 0x8000) !== 0 ? -1 : 1;
  const exponent = (bits >> 10) & 0x1f;
  const mantissa = bits & 0x03ff;

  if (exponent === 0) {
    // Zero or subnormal: no implicit leading 1, fixed exponent of -14,
    // i.e. mantissa / 1024 * 2^-14.
    return sign * mantissa * 2 ** -24;
  }

  if (exponent === 0x1f) {
    return mantissa === 0 ? sign * Number.POSITIVE_INFINITY : Number.NaN;
  }

  return sign * (1 + mantissa / 1024) * 2 ** (exponent - 15);
}
diff --git a/packages/typed-binary/src/structure/index.ts b/packages/typed-binary/src/structure/index.ts index 89e8c77..839f547 100644 --- a/packages/typed-binary/src/structure/index.ts +++ b/packages/typed-binary/src/structure/index.ts @@ -17,6 +17,7 @@ export { ByteSchema, Int32Schema, Uint32Schema, + Float16Schema, Float32Schema, ArraySchema, CharsSchema, diff --git a/packages/typed-binary/src/test/float.test.ts b/packages/typed-binary/src/test/float.test.ts index 577c249..ad31168 100644 --- a/packages/typed-binary/src/test/float.test.ts +++ b/packages/typed-binary/src/test/float.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it } from 'vitest'; -import { f32 } from '../describe'; +import { f16, f32 } from '../describe'; import { encodeAndDecode } from './helpers/mock'; import { randBetween } from './random'; @@ -12,3 +12,73 @@ describe('Float32Schema', () => { expect(decoded).to.closeTo(value, 0.01); }); }); + +describe('Float16Schema', () => { + it('should encode and decode a f16 value near or equal a power of 2 with high precision', () => { + // near 2^-10 + const value1 = 0.000976; + // near 2^-5 + const value2 = 0.031; + // 2^-2 + const value3 = 0.25; + // 2^0 + const value4 = 1; + // 2^4 + const value5 = 16; + // 2^8 + const value6 = 256; + // 2^12 + const value7 = 4096; + // 2^15 × (1 + 1023/1024) - largest representable value + const value8 = 65504; + + const decoded1 = encodeAndDecode(f16, value1); + const decoded2 = encodeAndDecode(f16, value2); + const decoded3 = encodeAndDecode(f16, value3); + const decoded4 = encodeAndDecode(f16, value4); + const decoded5 = encodeAndDecode(f16, value5); + const decoded6 = encodeAndDecode(f16, value6); + const decoded7 = encodeAndDecode(f16, value7); + const decoded8 = encodeAndDecode(f16, value8); + + expect(decoded1).to.closeTo(value1, 0.0001); + expect(decoded2).to.closeTo(value2, 0.0001); + expect(decoded3).to.closeTo(value3, 0.0001); + expect(decoded4).to.closeTo(value4, 0.0001); + 
expect(decoded5).to.closeTo(value5, 0.0001); + expect(decoded6).to.closeTo(value6, 0.0001); + expect(decoded7).to.closeTo(value7, 0.0001); + expect(decoded8).to.closeTo(value8, 0.0001); + }); + + it('should encode and decode a f16 value', () => { + const value1 = 5472.5; // precision should be 4 + const value2 = 145; // precision should be 2^-2 + const value3 = 0.34; // precision should be 2^-12 + const value4 = 21877.5; // precision should be 16 + + const decoded1 = encodeAndDecode(f16, value1); + const decoded2 = encodeAndDecode(f16, value2); + const decoded3 = encodeAndDecode(f16, value3); + const decoded4 = encodeAndDecode(f16, value4); + + expect(decoded1).to.closeTo(value1, 4); + expect(decoded2).to.closeTo(value2, 0.25); + expect(decoded3).to.closeTo(value3, 0.000976); + expect(decoded4).to.closeTo(value4, 16); + }); + + it('should handle NaN and Infinity', () => { + const value1 = Number.POSITIVE_INFINITY; + const value2 = Number.NEGATIVE_INFINITY; + const value3 = Number.NaN; + + const decoded1 = encodeAndDecode(f16, value1); + const decoded2 = encodeAndDecode(f16, value2); + const decoded3 = encodeAndDecode(f16, value3); + + expect(decoded1).to.equal(value1); + expect(decoded2).to.equal(value2); + expect(decoded3).to.be.NaN; + }); +});