Skip to content

Commit d66c17b

Browse files
Support 4GB of memory in WASM backend (#7647)
Set max WASM memory to 4GB and fix negative int32 malloc indices.
1 parent eb1b329 commit d66c17b

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

tfjs-backend-wasm/src/backend_wasm.ts

+5-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ export class BackendWasm extends KernelBackend {
9292

9393
const size = util.sizeFromShape(shape);
9494
const numBytes = size * util.bytesPerElement(dtype);
95-
const memoryOffset = this.wasm._malloc(numBytes);
95+
96+
// `>>> 0` is needed for above 2GB allocations because wasm._malloc returns
97+
// a signed int32 instead of an unsigned int32.
98+
// https://v8.dev/blog/4gb-wasm-memory
99+
const memoryOffset = this.wasm._malloc(numBytes) >>> 0;
96100

97101
this.dataIdMap.set(dataId, {id, memoryOffset, shape, dtype, refCount});
98102

tfjs-backend-wasm/src/cc/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ KERNELS_WITH_KEEPALIVE = glob(
1212

1313
BASE_LINKOPTS = [
1414
"-s ALLOW_MEMORY_GROWTH=1",
15+
"-s MAXIMUM_MEMORY=4GB",
1516
"-s DEFAULT_LIBRARY_FUNCS_TO_INCLUDE=[]",
1617
"-s DISABLE_EXCEPTION_CATCHING=1",
1718
"-s FILESYSTEM=0",

tfjs-backend-wasm/src/index_test.ts

+24
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,30 @@ describeWithFlags('wasm read/write', ALL_ENVS, () => {
6464
// Tensor values should match.
6565
test_util.expectArraysClose(await t.data(), view);
6666
});
67+
68+
it('allocates more than two gigabytes', async () => {
69+
const size = 2**30 / 4; // 2**30 bytes (4 bytes per number) = 1GB
70+
71+
// Allocate 3 gigabytes.
72+
const t0 = tf.zeros([size], 'float32');
73+
const t1 = tf.ones([size], 'float32');
74+
const t2 = t1.mul(2);
75+
76+
// Helper function to check if all the values in a tensor equal an expected
77+
// value.
78+
async function check(tensor: tf.Tensor, name: string, val: number) {
79+
const arr = await tensor.data();
80+
for (let i = 0; i < size; i++) {
81+
if (arr[i] !== val) {
82+
throw new Error(`${name}[${i}] == ${arr[i]} but should be ${val}`);
83+
}
84+
}
85+
}
86+
87+
await check(t0, 't0', 0);
88+
await check(t1, 't1', 1);
89+
await check(t2, 't2', 2);
90+
});
6791
});
6892

6993
describeWithFlags('wasm init', BROWSER_ENVS, () => {

0 commit comments

Comments
 (0)