diff --git a/sdks/typescript/src/apache_beam/worker/state.ts b/sdks/typescript/src/apache_beam/worker/state.ts index 5a340cbb64f0..dee7a866c2c7 100644 --- a/sdks/typescript/src/apache_beam/worker/state.ts +++ b/sdks/typescript/src/apache_beam/worker/state.ts @@ -46,12 +46,102 @@ export interface StateProvider { } // TODO: (Advanced) Cross-bundle caching. +/** + * Wrapper for cached values that tracks their weight (memory size). + */ +interface WeightedCacheEntry { + entry: MaybePromise; + weight: number; +} + +/** + * Estimates the memory size of a value in bytes. + * This is a simplified estimation - actual memory usage may vary. + */ +function estimateSize(value: any): number { + if (value === null || value === undefined) { + return 8; + } + + const type = typeof value; + + if (type === "boolean") { + return 4; + } + if (type === "number") { + return 8; + } + if (type === "string") { + // Each character is 2 bytes in JavaScript (UTF-16) + overhead + return 40 + value.length * 2; + } + if (value instanceof Uint8Array || value instanceof Buffer) { + return 40 + value.length; + } + if (Array.isArray(value)) { + let size = 40; // Array overhead + for (const item of value) { + size += estimateSize(item); + } + return size; + } + if (type === "object") { + let size = 40; // Object overhead + for (const key of Object.keys(value)) { + size += estimateSize(key) + estimateSize(value[key]); + } + return size; + } + + // Default for unknown types + return 64; +} + +// Default cache size: 100MB +const DEFAULT_MAX_CACHE_WEIGHT = 100 * 1024 * 1024; + export class CachingStateProvider implements StateProvider { underlying: StateProvider; - cache: Map> = new Map(); + cache: Map> = new Map(); + maxCacheWeight: number; + currentWeight: number = 0; - constructor(underlying: StateProvider) { + constructor( + underlying: StateProvider, + maxCacheWeight: number = DEFAULT_MAX_CACHE_WEIGHT, + ) { this.underlying = underlying; + this.maxCacheWeight = maxCacheWeight; + } + + /** + * Evicts least recently used entries until the cache is under the weight limit. + * JavaScript Maps preserve insertion order, so the first entry is the oldest. + */ + private evictIfNeeded() { + while (this.currentWeight > this.maxCacheWeight && this.cache.size > 0) { + // Remove the first (oldest) entry + const firstKey = this.cache.keys().next().value; + if (firstKey !== undefined) { + const evicted = this.cache.get(firstKey); + if (evicted !== undefined) { + this.currentWeight -= evicted.weight; + } + this.cache.delete(firstKey); + } + } + } + + /** + * Moves a cache entry to the end (most recently used) by deleting and re-adding it. + * This maintains LRU order: most recently accessed items are at the end. + */ + private touchCacheEntry(cacheKey: string) { + const value = this.cache.get(cacheKey); + if (value !== undefined) { + this.cache.delete(cacheKey); + this.cache.set(cacheKey, value); + } } getState(stateKey: fnApi.StateKey, decode: (data: Uint8Array) => T) { @@ -62,21 +152,40 @@ export class CachingStateProvider implements StateProvider { "base64", ); if (this.cache.has(cacheKey)) { - return this.cache.get(cacheKey)!; + // Cache hit: move to end (most recently used) + this.touchCacheEntry(cacheKey); + return this.cache.get(cacheKey)!.entry; } + // Cache miss: fetch from underlying provider let result = this.underlying.getState(stateKey, decode); - const this_ = this; if (result.type === "promise") { result = { type: "promise", promise: result.promise.then((value) => { - this_.cache.set(cacheKey, { type: "value", value }); + // When promise resolves, update cache with resolved value + // Get the current entry to update its weight + const currentEntry = this.cache.get(cacheKey); + if (currentEntry !== undefined) { + // Remove old weight from total + this.currentWeight -= currentEntry.weight; + } + const resolvedWeight = estimateSize(value); + this.cache.set(cacheKey, { + entry: { type: "value", value }, + weight: resolvedWeight, + }); + this.currentWeight += resolvedWeight; + this.evictIfNeeded(); return value; }), }; } - // TODO: (Perf) Cache eviction. - this.cache.set(cacheKey, result); + // Estimate weight for the new entry + const weight = result.type === "value" ? estimateSize(result.value) : 64; // Promise placeholder weight + // Evict if needed before adding new entry + this.currentWeight += weight; + this.evictIfNeeded(); + this.cache.set(cacheKey, { entry: result, weight }); return result; } } diff --git a/sdks/typescript/test/state_provider_test.ts b/sdks/typescript/test/state_provider_test.ts new file mode 100644 index 000000000000..e830754129ab --- /dev/null +++ b/sdks/typescript/test/state_provider_test.ts @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as assert from "assert"; +import { + CachingStateProvider, + StateProvider, + MaybePromise, +} from "../src/apache_beam/worker/state"; +import * as fnApi from "../src/apache_beam/proto/beam_fn_api"; + +/** + * Mock StateProvider for testing that tracks call counts. + */ +class MockStateProvider implements StateProvider { + callCount: number = 0; + values: Map = new Map(); + delayMs: number = 0; + + constructor(delayMs: number = 0) { + this.delayMs = delayMs; + } + + setValue(key: string, value: any) { + this.values.set(key, value); + } + + getState( + stateKey: fnApi.StateKey, + decode: (data: Uint8Array) => T, + ): MaybePromise { + this.callCount++; + const key = Buffer.from(fnApi.StateKey.toBinary(stateKey)).toString( + "base64", + ); + const value = this.values.get(key); + + if (this.delayMs > 0) { + return { + type: "promise", + promise: new Promise((resolve) => { + setTimeout(() => resolve(value), this.delayMs); + }), + }; + } else { + return { type: "value", value }; + } + } +} + +describe("CachingStateProvider", function () { + it("caches values and returns cached result on subsequent calls", function () { + const mockProvider = new MockStateProvider(); + // Use large weight limit to ensure no eviction for this test + const cache = new CachingStateProvider(mockProvider, 10 * 1024); + + const stateKey: fnApi.StateKey = { + type: { + oneofKind: "bagUserState", + bagUserState: fnApi.StateKey_BagUserState.create({ + transformId: "test", + userStateId: "state1", + window: new Uint8Array(0), + key: new Uint8Array(0), + }), + }, + }; + + const decode = (data: Uint8Array) => data.toString(); + + // Set value in mock + const testValue = "cached_value"; + const key = Buffer.from(fnApi.StateKey.toBinary(stateKey)).toString( + "base64", + ); + mockProvider.setValue(key, testValue); + + // First call should hit underlying provider + const result1 = cache.getState(stateKey, decode); + assert.equal(mockProvider.callCount, 1); + assert.equal(result1.type, "value"); + if (result1.type === "value") { + assert.equal(result1.value, testValue); + } + + // Second call should use cache + const result2 = cache.getState(stateKey, decode); + assert.equal(mockProvider.callCount, 1); // Still 1, not 2 + assert.equal(result2.type, "value"); + if (result2.type === "value") { + assert.equal(result2.value, testValue); + } + }); + + it("evicts least recently used entry when cache weight exceeds limit", function () { + const mockProvider = new MockStateProvider(); + // Each small string "valueX" is approximately 52 bytes (40 + 6*2) + // Set weight limit to hold approximately 3 entries + const cache = new CachingStateProvider(mockProvider, 180); + + const decode = (data: Uint8Array) => data.toString(); + + // Create 4 different state keys + const keys: fnApi.StateKey[] = []; + for (let i = 0; i < 4; i++) { + keys.push({ + type: { + oneofKind: "bagUserState", + bagUserState: fnApi.StateKey_BagUserState.create({ + transformId: "test", + userStateId: `state${i}`, + window: new Uint8Array(0), + key: new Uint8Array(0), + }), + }, + }); + } + + // Set values in mock + for (let i = 0; i < 4; i++) { + const key = Buffer.from(fnApi.StateKey.toBinary(keys[i])).toString( + "base64", + ); + mockProvider.setValue(key, `value${i}`); + } + + // Fill cache with 3 entries + cache.getState(keys[0], decode); + cache.getState(keys[1], decode); + cache.getState(keys[2], decode); + assert.equal(mockProvider.callCount, 3); + assert.equal(cache.cache.size, 3); + + // Access keys[0] to make it most recently used + cache.getState(keys[0], decode); + assert.equal(mockProvider.callCount, 3); // Still cached + + // Add 4th entry - should evict keys[1] (least recently used, not keys[0]) + cache.getState(keys[3], decode); + assert.equal(mockProvider.callCount, 4); + + // keys[1] should be evicted (not in cache) + const result1 = cache.getState(keys[1], decode); + assert.equal(mockProvider.callCount, 5); // Had to fetch again + assert.equal(result1.type, "value"); + if (result1.type === "value") { + assert.equal(result1.value, "value1"); + } + + // keys[0] should still be cached (was most recently used) + const result0 = cache.getState(keys[0], decode); + assert.equal(mockProvider.callCount, 5); // Still cached, no new call + assert.equal(result0.type, "value"); + if (result0.type === "value") { + assert.equal(result0.value, "value0"); + } + }); + + it("handles promise-based state fetches correctly", async function () { + const mockProvider = new MockStateProvider(10); // 10ms delay + // Use large weight limit to ensure no eviction for this test + const cache = new CachingStateProvider(mockProvider, 10 * 1024); + + const stateKey: fnApi.StateKey = { + type: { + oneofKind: "bagUserState", + bagUserState: fnApi.StateKey_BagUserState.create({ + transformId: "test", + userStateId: "async_state", + window: new Uint8Array(0), + key: new Uint8Array(0), + }), + }, + }; + + const decode = (data: Uint8Array) => data.toString(); + const key = Buffer.from(fnApi.StateKey.toBinary(stateKey)).toString( + "base64", + ); + mockProvider.setValue(key, "async_value"); + + // First call returns promise + const result1 = cache.getState(stateKey, decode); + assert.equal(result1.type, "promise"); + assert.equal(mockProvider.callCount, 1); + + // Wait for promise to resolve + if (result1.type === "promise") { + const value1 = await result1.promise; + assert.equal(value1, "async_value"); + + // Second call should return cached value (not promise) + const result2 = cache.getState(stateKey, decode); + assert.equal(result2.type, "value"); + assert.equal(mockProvider.callCount, 1); // Still only 1 call + if (result2.type === "value") { + assert.equal(result2.value, "async_value"); + } + } + }); + + it("respects custom maxCacheWeight and evicts based on memory size", function () { + const mockProvider = new MockStateProvider(); + // Set weight limit to hold approximately 2 small string entries + const cache = new CachingStateProvider(mockProvider, 120); + + const decode = (data: Uint8Array) => data.toString(); + + const keys: fnApi.StateKey[] = []; + for (let i = 0; i < 3; i++) { + keys.push({ + type: { + oneofKind: "bagUserState", + bagUserState: fnApi.StateKey_BagUserState.create({ + transformId: "test", + userStateId: `state${i}`, + window: new Uint8Array(0), + key: new Uint8Array(0), + }), + }, + }); + const key = Buffer.from(fnApi.StateKey.toBinary(keys[i])).toString( + "base64", + ); + mockProvider.setValue(key, `value${i}`); + } + + // Fill cache with 2 entries + cache.getState(keys[0], decode); + cache.getState(keys[1], decode); + assert.equal(cache.cache.size, 2); + + // Add 3rd entry - should evict oldest to stay under weight limit + cache.getState(keys[2], decode); + + // First entry should be evicted + cache.getState(keys[0], decode); + assert.equal(mockProvider.callCount, 4); // Had to fetch keys[0] again + }); + + it("tracks cache weight correctly", function () { + const mockProvider = new MockStateProvider(); + const cache = new CachingStateProvider(mockProvider, 10 * 1024); + + const decode = (data: Uint8Array) => data.toString(); + + const stateKey: fnApi.StateKey = { + type: { + oneofKind: "bagUserState", + bagUserState: fnApi.StateKey_BagUserState.create({ + transformId: "test", + userStateId: "state1", + window: new Uint8Array(0), + key: new Uint8Array(0), + }), + }, + }; + + const key = Buffer.from(fnApi.StateKey.toBinary(stateKey)).toString( + "base64", + ); + mockProvider.setValue(key, "test_value"); + + // Cache should start with 0 weight + assert.equal(cache.currentWeight, 0); + + // After adding an entry, weight should increase + cache.getState(stateKey, decode); + assert.ok(cache.currentWeight > 0); + }); +}); +