|
| 1 | +/* |
| 2 | + * This source file contains the code for proxying calls in the master thread to calls in the workers |
| 3 | + * by `.postMessage()`-ing. |
| 4 | + * |
| 5 | + * Keep in mind that this code can make or break the program's performance! Need to optimize more… |
| 6 | + */ |
| 7 | + |
| 8 | +import { Debugger } from "debug" |
| 9 | +import isSomeObservable from "is-observable" |
| 10 | +import { multicast, Observable, Subscription } from "observable-fns" |
| 11 | +import { MessageRelay } from "../types/common" |
| 12 | +import { |
| 13 | + ModuleMethods, |
| 14 | + ModuleProxy, |
| 15 | + ProxyableFunction |
| 16 | +} from "../types/master" |
| 17 | +import { |
| 18 | + CallCancelMessage, |
| 19 | + CallErrorMessage, |
| 20 | + CallInvocationMessage, |
| 21 | + CallResultMessage, |
| 22 | + CallRunningMessage, |
| 23 | + CommonMessageType |
| 24 | +} from "../types/messages" |
| 25 | +import { SerializedError, Serializer } from "../types/serializers" |
| 26 | +import { lookupLocalCallback, Callback } from "./callbacks" |
| 27 | +import { ObservablePromise } from "./observable-promise" |
| 28 | +import { isTransferDescriptor } from "./transferable" |
| 29 | + |
| 30 | +let nextCallID = 1 |
| 31 | + |
| 32 | +const activeSubscriptions = new Map<number, Subscription<any>>() |
| 33 | + |
| 34 | +const dedupe = <T>(array: T[]): T[] => Array.from(new Set(array)) |
| 35 | + |
| 36 | +const isCallCancelMessage = (data: any): data is CallCancelMessage => data && data.type === CommonMessageType.cancel |
| 37 | +const isCallErrorMessage = (data: any): data is CallErrorMessage => data && data.type === CommonMessageType.error |
| 38 | +const isCallResultMessage = (data: any): data is CallResultMessage => data && data.type === CommonMessageType.result |
| 39 | +const isCallRunningMessage = (data: any): data is CallRunningMessage => data && data.type === CommonMessageType.running |
| 40 | +const isInvocationMessage = (data: any): data is CallInvocationMessage => data && data.type === CommonMessageType.invoke |
| 41 | + |
| 42 | +function isZenObservable(thing: any): thing is Observable<any> { |
| 43 | + return thing && typeof thing === "object" && typeof thing.subscribe === "function" |
| 44 | +} |
| 45 | + |
| 46 | +/** |
| 47 | + * There are issues with `is-observable` not recognizing zen-observable's instances. |
| 48 | + * We are using `observable-fns`, but it's based on zen-observable, too. |
| 49 | + */ |
| 50 | +function isObservable(thing: any): thing is Observable<any> { |
| 51 | + return isSomeObservable(thing) || isZenObservable(thing) |
| 52 | +} |
| 53 | + |
| 54 | +function deconstructTransfer(thing: any) { |
| 55 | + return isTransferDescriptor(thing) |
| 56 | + ? { payload: thing.send, transferables: thing.transferables } |
| 57 | + : { payload: thing, transferables: undefined } |
| 58 | +} |
| 59 | + |
| 60 | +function postCallError(relay: MessageRelay, uid: number, rawError: SerializedError) { |
| 61 | + const { payload: error, transferables } = deconstructTransfer(rawError) |
| 62 | + const errorMessage: CallErrorMessage = { |
| 63 | + type: CommonMessageType.error, |
| 64 | + uid, |
| 65 | + error |
| 66 | + } |
| 67 | + relay.postMessage(errorMessage, transferables) |
| 68 | +} |
| 69 | + |
| 70 | +function postCallResult(relay: MessageRelay, uid: number, completed: boolean, resultValue?: any) { |
| 71 | + const { payload, transferables } = deconstructTransfer(resultValue) |
| 72 | + const resultMessage: CallResultMessage = { |
| 73 | + type: CommonMessageType.result, |
| 74 | + uid, |
| 75 | + complete: completed ? true : undefined, |
| 76 | + payload |
| 77 | + } |
| 78 | + relay.postMessage(resultMessage, transferables) |
| 79 | +} |
| 80 | + |
| 81 | +function postCallRunning(relay: MessageRelay, uid: number, resultType: CallRunningMessage["resultType"]) { |
| 82 | + const startMessage: CallRunningMessage = { |
| 83 | + type: CommonMessageType.running, |
| 84 | + uid, |
| 85 | + resultType |
| 86 | + } |
| 87 | + relay.postMessage(startMessage) |
| 88 | +} |
| 89 | + |
| 90 | +function createObservableForJob<ResultType>( |
| 91 | + relay: MessageRelay, |
| 92 | + serializer: Serializer, |
| 93 | + callID: number, |
| 94 | + debug: Debugger |
| 95 | +): Observable<ResultType> { |
| 96 | + return new Observable(observer => { |
| 97 | + let asyncType: "observable" | "promise" | undefined |
| 98 | + |
| 99 | + const messageHandler = ((event: MessageEvent) => { |
| 100 | + const message = event.data |
| 101 | + |
| 102 | + if (!message || message.uid !== callID) return |
| 103 | + debug(`Received message for running call ${callID}:`, message) |
| 104 | + |
| 105 | + if (isCallRunningMessage(message)) { |
| 106 | + asyncType = message.resultType |
| 107 | + } else if (isCallResultMessage(message)) { |
| 108 | + if (asyncType === "promise") { |
| 109 | + if (typeof message.payload !== "undefined") { |
| 110 | + observer.next(serializer.deserialize(message.payload, relay)) |
| 111 | + } |
| 112 | + observer.complete() |
| 113 | + relay.removeEventListener("message", messageHandler) |
| 114 | + } else { |
| 115 | + if (message.payload) { |
| 116 | + observer.next(serializer.deserialize(message.payload, relay)) |
| 117 | + } |
| 118 | + if (message.complete) { |
| 119 | + observer.complete() |
| 120 | + relay.removeEventListener("message", messageHandler) |
| 121 | + } |
| 122 | + } |
| 123 | + } else if (isCallErrorMessage(message)) { |
| 124 | + const error = serializer.deserialize(message.error as any, relay) |
| 125 | + if (asyncType === "promise" || !asyncType) { |
| 126 | + observer.error(error) |
| 127 | + } else { |
| 128 | + observer.error(error) |
| 129 | + } |
| 130 | + relay.removeEventListener("message", messageHandler) |
| 131 | + } |
| 132 | + }) as EventListener |
| 133 | + |
| 134 | + relay.addEventListener("message", messageHandler) |
| 135 | + |
| 136 | + return () => { |
| 137 | + if (asyncType === "observable" || !asyncType) { |
| 138 | + const cancelMessage: CallCancelMessage = { |
| 139 | + type: CommonMessageType.cancel, |
| 140 | + uid: callID |
| 141 | + } |
| 142 | + relay.postMessage(cancelMessage) |
| 143 | + } |
| 144 | + relay.removeEventListener("message", messageHandler) |
| 145 | + } |
| 146 | + }) |
| 147 | +} |
| 148 | + |
| 149 | +function prepareArguments(serializer: Serializer, rawArgs: any[]): { args: any[], transferables: Transferable[] } { |
| 150 | + if (rawArgs.length === 0) { |
| 151 | + // Exit early if possible |
| 152 | + return { |
| 153 | + args: [], |
| 154 | + transferables: [] |
| 155 | + } |
| 156 | + } |
| 157 | + |
| 158 | + const args: any[] = [] |
| 159 | + const transferables: Transferable[] = [] |
| 160 | + |
| 161 | + for (const arg of rawArgs) { |
| 162 | + if (isTransferDescriptor(arg)) { |
| 163 | + args.push(serializer.serialize(arg.send)) |
| 164 | + transferables.push(...arg.transferables) |
| 165 | + } else { |
| 166 | + args.push(serializer.serialize(arg)) |
| 167 | + } |
| 168 | + } |
| 169 | + |
| 170 | + return { |
| 171 | + args, |
| 172 | + transferables: transferables.length === 0 ? transferables : dedupe(transferables) |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +export function createProxyFunction<Args extends any[], ReturnType>( |
| 177 | + relay: MessageRelay, |
| 178 | + serializer: Serializer, |
| 179 | + fid: number, |
| 180 | + debug: Debugger |
| 181 | +) { |
| 182 | + return ((...rawArgs: Args) => { |
| 183 | + const uid = nextCallID++ |
| 184 | + const { args, transferables } = prepareArguments(serializer, rawArgs) |
| 185 | + const runMessage: CallInvocationMessage = { |
| 186 | + type: CommonMessageType.invoke, |
| 187 | + fid, |
| 188 | + uid, |
| 189 | + args |
| 190 | + } |
| 191 | + |
| 192 | + debug("Sending command to run function to worker:", runMessage) |
| 193 | + |
| 194 | + try { |
| 195 | + relay.postMessage(runMessage, transferables) |
| 196 | + } catch (error) { |
| 197 | + return ObservablePromise.from(Promise.reject(error)) |
| 198 | + } |
| 199 | + |
| 200 | + return ObservablePromise.from(multicast(createObservableForJob<ReturnType>(relay, serializer, uid, debug))) |
| 201 | + }) as any as ProxyableFunction<Args, ReturnType> |
| 202 | +} |
| 203 | + |
| 204 | +export function createProxyModule<Methods extends ModuleMethods>( |
| 205 | + relay: MessageRelay, |
| 206 | + serializer: Serializer, |
| 207 | + methods: Record<string, number>, |
| 208 | + debug: Debugger |
| 209 | +): ModuleProxy<Methods> { |
| 210 | + const proxy: any = {} |
| 211 | + |
| 212 | + for (const methodName of Object.keys(methods)) { |
| 213 | + proxy[methodName] = createProxyFunction(relay, serializer, methods[methodName], debug) |
| 214 | + } |
| 215 | + |
| 216 | + return proxy |
| 217 | +} |
| 218 | + |
| 219 | +async function invokeExposedLocalFunction( |
| 220 | + relay: MessageRelay, |
| 221 | + serializer: Serializer, |
| 222 | + callback: Callback, |
| 223 | + message: CallInvocationMessage |
| 224 | +) { |
| 225 | + let syncResult: any |
| 226 | + const uid = message.uid |
| 227 | + |
| 228 | + try { |
| 229 | + const args = message.args.map(arg => serializer.deserialize(arg, relay)) |
| 230 | + syncResult = callback(...args) |
| 231 | + } catch (error) { |
| 232 | + postCallError(relay, uid, serializer.serialize(error) as any as SerializedError) |
| 233 | + } |
| 234 | + |
| 235 | + const resultType = isObservable(syncResult) ? "observable" : "promise" |
| 236 | + postCallRunning(relay, uid, resultType) |
| 237 | + |
| 238 | + if (isObservable(syncResult)) { |
| 239 | + const subscription = syncResult.subscribe( |
| 240 | + value => postCallResult(relay, uid, false, serializer.serialize(value)), |
| 241 | + error => postCallError(relay, uid, serializer.serialize(error) as any), |
| 242 | + () => postCallResult(relay, uid, true) |
| 243 | + ) |
| 244 | + activeSubscriptions.set(uid, subscription) |
| 245 | + } else { |
| 246 | + try { |
| 247 | + const result = await syncResult |
| 248 | + postCallResult(relay, uid, true, serializer.serialize(result)) |
| 249 | + } catch (error) { |
| 250 | + postCallError(relay, uid, serializer.serialize(error) as any) |
| 251 | + } |
| 252 | + } |
| 253 | +} |
| 254 | + |
| 255 | +function handleRemoteInvocation( |
| 256 | + relay: MessageRelay, |
| 257 | + serializer: Serializer, |
| 258 | + message: CallInvocationMessage, |
| 259 | + debug: Debugger |
| 260 | +) { |
| 261 | + const callback = lookupLocalCallback(message.fid) |
| 262 | + |
| 263 | + if (!callback) { |
| 264 | + debug(`Call to exposed local function failed: Function not found: UID ${message.uid}`) |
| 265 | + return postCallError(relay, message.uid, serializer.serialize(Error(`Function not found: UID ${message.uid}`)) as any as SerializedError) |
| 266 | + } |
| 267 | + |
| 268 | + debug(`Received invocation of local exposed function ${message.fid}, call UID ${message.uid} with arguments:`, message.args) |
| 269 | + return invokeExposedLocalFunction(relay, serializer, callback, message) |
| 270 | +} |
| 271 | + |
| 272 | +export function handleFunctionInvocations(relay: MessageRelay, serializer: Serializer, debug: Debugger) { |
| 273 | + relay.addEventListener("message", (event: MessageEvent) => { |
| 274 | + debug(`Received message:`, event.data) |
| 275 | + |
| 276 | + if (isInvocationMessage(event.data)) { |
| 277 | + handleRemoteInvocation(relay, serializer, event.data, debug) |
| 278 | + } |
| 279 | + }) |
| 280 | +} |
| 281 | + |
| 282 | +export function handleCallCancellations(relay: MessageRelay, debug: Debugger) { |
| 283 | + relay.addEventListener("message", event => { |
| 284 | + const messageData = event.data |
| 285 | + |
| 286 | + if (isCallCancelMessage(messageData)) { |
| 287 | + const jobUID = messageData.uid |
| 288 | + const subscription = activeSubscriptions.get(jobUID) |
| 289 | + |
| 290 | + if (subscription) { |
| 291 | + subscription.unsubscribe() |
| 292 | + activeSubscriptions.delete(jobUID) |
| 293 | + } |
| 294 | + } |
| 295 | + }) |
| 296 | +} |
0 commit comments