Skip to content

Commit

Permalink
implementation of float16
Browse files Browse the repository at this point in the history
  • Loading branch information
reczkok committed Oct 14, 2024
1 parent 0202929 commit a9028be
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 5 deletions.
3 changes: 3 additions & 0 deletions packages/typed-binary/src/describe/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
ByteSchema,
CharsSchema,
DynamicArraySchema,
Float16Schema,
Float32Schema,
GenericObjectSchema,
Int32Schema,
Expand Down Expand Up @@ -36,6 +37,8 @@ export const i32 = new Int32Schema();

export const u32 = new Uint32Schema();

export const f16 = new Float16Schema();

export const f32 = new Float32Schema();

export const chars = <T extends number>(length: T) => new CharsSchema(length);
Expand Down
6 changes: 4 additions & 2 deletions packages/typed-binary/src/io/bufferIOBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ export class BufferIOBase {
protected readonly uint8View: Uint8Array;
protected readonly helperInt32View: Int32Array;
protected readonly helperUint32View: Uint32Array;
protected readonly helperFloatView: Float32Array;
protected readonly helperFloat16View: Uint16Array;
protected readonly helperFloat32View: Float32Array;
protected readonly helperByteView: Uint8Array;
protected readonly switchEndianness: boolean;

Expand All @@ -43,8 +44,9 @@ export class BufferIOBase {
const helperBuffer = new ArrayBuffer(4);
this.helperInt32View = new Int32Array(helperBuffer);
this.helperUint32View = new Uint32Array(helperBuffer);
this.helperFloatView = new Float32Array(helperBuffer);
this.helperFloat32View = new Float32Array(helperBuffer);
this.helperByteView = new Uint8Array(helperBuffer);
this.helperFloat16View = new Uint16Array(helperBuffer);
}

get currentByteOffset() {
Expand Down
9 changes: 8 additions & 1 deletion packages/typed-binary/src/io/bufferReader.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { BufferIOBase } from './bufferIOBase';
import { float16ToNumber } from './float16converter';
import type { ISerialInput } from './types';
import { unwrapBuffer } from './unwrapBuffer';

Expand Down Expand Up @@ -27,10 +28,16 @@ export class BufferReader extends BufferIOBase implements ISerialInput {
return this.uint8View[this.byteOffset++];
}

readFloat16() {
this.copyInputToHelper(2);

return float16ToNumber(this.helperFloat16View);
}

readFloat32() {
this.copyInputToHelper(4);

return this.helperFloatView[0];
return this.helperFloat32View[0];
}

readInt32() {
Expand Down
9 changes: 8 additions & 1 deletion packages/typed-binary/src/io/bufferWriter.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { BufferIOBase } from './bufferIOBase';
import { numberToFloat16 } from './float16converter';
import type { ISerialOutput } from './types';
import { unwrapBuffer } from './unwrapBuffer';

Expand Down Expand Up @@ -26,6 +27,12 @@ export class BufferWriter extends BufferIOBase implements ISerialOutput {
this.uint8View[this.byteOffset++] = Math.floor(value) % 256;
}

writeFloat16(value: number): void {
this.helperFloat16View[0] = numberToFloat16(value)[0];

this.copyHelperToOutput(2);
}

writeInt32(value: number) {
this.helperInt32View[0] = Math.floor(value);

Expand All @@ -39,7 +46,7 @@ export class BufferWriter extends BufferIOBase implements ISerialOutput {
}

writeFloat32(value: number) {
this.helperFloatView[0] = value;
this.helperFloat32View[0] = value;

this.copyHelperToOutput(4);
}
Expand Down
36 changes: 36 additions & 0 deletions packages/typed-binary/src/io/float16converter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
export function numberToFloat16(value: number): Uint16Array {
// conversion according to IEEE 754 binary16 format
if (value === 0) return new Uint16Array([0]);
if (Number.isNaN(value)) return new Uint16Array([0x7e00]);
if (!Number.isFinite(value))
return new Uint16Array([value > 0 ? 0x7c00 : 0xfc00]);

const sign = value < 0 ? 1 : 0;
const absValue = Math.abs(value);
const exponent = Math.floor(Math.log2(absValue));
const mantissa = absValue / 2 ** exponent - 1;
const biasedExponent = exponent + 15;
const mantissaBits = Math.floor(mantissa * 1024);
const float16 = (sign << 15) | (biasedExponent << 10) | mantissaBits;
const uint16Array = new Uint16Array(1);
uint16Array[0] = float16;
return uint16Array;
}

export function float16ToNumber(uint16Array: Uint16Array): number {
const float16 = uint16Array[0];
const sign = (float16 & 0x8000) >> 15;
const exponent = (float16 & 0x7c00) >> 10;
const mantissa = float16 & 0x3ff;
if (exponent === 0) {
return sign === 0 ? mantissa / 1024 : -mantissa / 1024;
}
if (exponent === 31) {
return mantissa === 0
? sign === 0
? Number.POSITIVE_INFINITY
: Number.NEGATIVE_INFINITY
: Number.NaN;
}
return (sign === 0 ? 1 : -1) * (1 + mantissa / 1024) * 2 ** (exponent - 15);
}
2 changes: 2 additions & 0 deletions packages/typed-binary/src/io/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface ISerialInput {
readInt32(): number;
readUint32(): number;
readFloat32(): number;
readFloat16(): number;
readString(): string;
readSlice(bufferView: BufferView, offset: number, byteLength: number): void;
seekTo(offset: number): void;
Expand All @@ -22,6 +23,7 @@ export interface ISerialOutput {
writeInt32(value: number): void;
writeUint32(value: number): void;
writeFloat32(value: number): void;
writeFloat16(value: number): void;
writeString(value: string): void;
writeSlice(bufferView: BufferView): void;
seekTo(offset: number): void;
Expand Down
24 changes: 24 additions & 0 deletions packages/typed-binary/src/structure/baseTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,30 @@ export class Uint32Schema extends Schema<number> {
// FLOAT
////

export class Float16Schema extends Schema<number> {
/**
* The maximum number of bytes this schema can take up.
*
* Alias for `.measure(MaxValue).size`
*/
readonly maxSize = 2;

read(input: ISerialInput): number {
return input.readFloat16();
}

write(output: ISerialOutput, value: number): void {
output.writeFloat16(value);
}

measure(
_: number | MaxValue,
measurer: IMeasurer = new Measurer(),
): IMeasurer {
return measurer.add(2);
}
}

export class Float32Schema extends Schema<number> {
/**
* The maximum number of bytes this schema can take up.
Expand Down
1 change: 1 addition & 0 deletions packages/typed-binary/src/structure/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export {
ByteSchema,
Int32Schema,
Uint32Schema,
Float16Schema,
Float32Schema,
ArraySchema,
CharsSchema,
Expand Down
58 changes: 57 additions & 1 deletion packages/typed-binary/src/test/float.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, expect, it } from 'vitest';

import { f32 } from '../describe';
import { f16, f32 } from '../describe';
import { encodeAndDecode } from './helpers/mock';
import { randBetween } from './random';

Expand All @@ -12,3 +12,59 @@ describe('Float32Schema', () => {
expect(decoded).to.closeTo(value, 0.01);
});
});

describe('Float16Schema', () => {
it('should encode and decode a f16 value near or equal a power of 2 with high precision', () => {
// near 2^-10
const value1 = 0.000976;
// near 2^-5
const value2 = 0.031;
// 2^-2
const value3 = 0.25;
// 2^0
const value4 = 1;
// 2^4
const value5 = 16;
// 2^8
const value6 = 256;
// 2^12
const value7 = 4096;
// 2^15 × (1 + 1023/1024) - largest representable value
const value8 = 65504;

const decoded1 = encodeAndDecode(f16, value1);
const decoded2 = encodeAndDecode(f16, value2);
const decoded3 = encodeAndDecode(f16, value3);
const decoded4 = encodeAndDecode(f16, value4);
const decoded5 = encodeAndDecode(f16, value5);
const decoded6 = encodeAndDecode(f16, value6);
const decoded7 = encodeAndDecode(f16, value7);
const decoded8 = encodeAndDecode(f16, value8);

expect(decoded1).to.closeTo(value1, 0.0001);
expect(decoded2).to.closeTo(value2, 0.0001);
expect(decoded3).to.closeTo(value3, 0.0001);
expect(decoded4).to.closeTo(value4, 0.0001);
expect(decoded5).to.closeTo(value5, 0.0001);
expect(decoded6).to.closeTo(value6, 0.0001);
expect(decoded7).to.closeTo(value7, 0.0001);
expect(decoded8).to.closeTo(value8, 0.0001);
});

it('should encode and decode a f16 value', () => {
const value1 = 5472.5; // precision should be 4
const value2 = 145; // precision should be 2^-2
const value3 = 0.34; // precision should be 2^-12
const value4 = 21877.5; // precision should be 16

const decoded1 = encodeAndDecode(f16, value1);
const decoded2 = encodeAndDecode(f16, value2);
const decoded3 = encodeAndDecode(f16, value3);
const decoded4 = encodeAndDecode(f16, value4);

expect(decoded1).to.closeTo(value1, 4);
expect(decoded2).to.closeTo(value2, 0.25);
expect(decoded3).to.closeTo(value3, 0.000976);
expect(decoded4).to.closeTo(value4, 16);
});
});

0 comments on commit a9028be

Please sign in to comment.