diff --git a/crates/base/test_cases/supabase-ai/index.ts b/crates/base/test_cases/supabase-ai/index.ts index 51d05b87..97e32d6b 100644 --- a/crates/base/test_cases/supabase-ai/index.ts +++ b/crates/base/test_cases/supabase-ai/index.ts @@ -1,7 +1,19 @@ -import { assertGreater, assertLessOrEqual } from "jsr:@std/assert"; +import { + assertEquals, + assertExists, + assertGreater, + assertIsError, + assertLessOrEqual, + assertStringIncludes, + assertThrows, +} from "jsr:@std/assert"; const session = new Supabase.ai.Session("gte-small"); +assertThrows(() => { + const _ = new Supabase.ai.Session("gte-small_wrong_name"); +}, "invalid 'Session' type"); + function dotProduct(a: number[], b: number[]) { let result = 0; for (let i = 0; i < a.length; i++) { @@ -15,27 +27,54 @@ export default { async fetch() { // Generate embedding // @ts-ignore unkwnow type - const meow: number[] = await session.run("meow", { - mean_pool: true, - normalize: true, - }); + const [meow, meowError] = await session.run("meow") as [ + number[], + undefined, + ]; // @ts-ignore unkwnow type - const love: number[] = await session.run("I love cats", { + const [love, loveError] = await session.run("I love cats", { mean_pool: true, normalize: true, - }); + }) as [number[], undefined]; + + // "Valid input should result in ok value" + { + assertExists(meow); + assertExists(love); + + assertEquals(meowError, undefined); + assertEquals(loveError, undefined); + } + + // "Invalid input should result in error value" + { + const [notCat, notCatError] = await session.run({ + bad_input: { "not a cat": "let fail" }, + }) as [undefined, { message: string; inner: Error }]; + + assertEquals(notCat, undefined); + + assertExists(notCatError); + assertIsError(notCatError.inner); + assertStringIncludes( + notCatError.message, + "must provide a valid prompt value", + ); + } - // Ensures `mean_pool` and `normalize` - const sameScore = dotProduct(meow, meow); - const diffScore = dotProduct(meow, love); + // "Ensures `mean_pool` and `normalize`" + { + const sameScore = dotProduct(meow, meow); + const diffScore = dotProduct(meow, love); - assertGreater(sameScore, 0.9); - assertGreater(diffScore, 0.5); - assertGreater(sameScore, diffScore); + assertGreater(sameScore, 0.9); + assertGreater(diffScore, 0.5); + assertGreater(sameScore, diffScore); - assertLessOrEqual(sameScore, 1); - assertLessOrEqual(diffScore, 1); + assertLessOrEqual(sameScore, 1); + assertLessOrEqual(diffScore, 1); + } return new Response( null, diff --git a/ext/ai/js/ai.d.ts b/ext/ai/js/ai.d.ts new file mode 100644 index 00000000..81a2ec00 --- /dev/null +++ b/ext/ai/js/ai.d.ts @@ -0,0 +1,21 @@ +import { Session } from "./ai.ts"; +import { LLMSessionRunInputOptions } from "./llm/llm_session.ts"; +import { + OllamaProviderInput, + OllamaProviderOptions, +} from "./llm/providers/ollama.ts"; +import { + OpenAIProviderInput, + OpenAIProviderOptions, +} from "./llm/providers/openai.ts"; + +export namespace ai { + export { Session }; + export { + LLMSessionRunInputOptions as LLMRunOptions, + OllamaProviderInput as OllamaInput, + OllamaProviderOptions as OllamaOptions, + OpenAIProviderInput as OpenAICompatibleInput, + OpenAIProviderOptions as OpenAICompatibleOptions, + }; +} diff --git a/ext/ai/js/ai.js b/ext/ai/js/ai.js deleted file mode 100644 index c2a30692..00000000 --- a/ext/ai/js/ai.js +++ /dev/null @@ -1,263 +0,0 @@ -import "ext:ai/onnxruntime/onnx.js"; -import EventSourceStream from "ext:ai/util/event_source_stream.mjs"; - -const core = globalThis.Deno.core; - -/** - * @param 
{ReadableStream p !== "")) { - try { - yield JSON.parse(part); - } catch (error) { - yield { error }; - } - } -}; - -/** - * @param {ReadableStream} */ - const reader = decoder.readable.getReader(); - - while (true) { - try { - if (signal.aborted) { - reader.cancel(signal.reason); - reader.releaseLock(); - return { error: signal.reason }; - } - - const { done, value } = await reader.read(); - - if (done) { - break; - } - - yield JSON.parse(value.data); - } catch (error) { - yield { error }; - } - } -}; - -class Session { - model; - init; - is_ext_inference_api; - inferenceAPIHost; - - constructor(model) { - this.model = model; - this.is_ext_inference_api = false; - - if (model === "gte-small") { - this.init = core.ops.op_ai_init_model(model); - } else { - this.inferenceAPIHost = core.ops.op_get_env("AI_INFERENCE_API_HOST"); - this.is_ext_inference_api = !!this.inferenceAPIHost; // only enable external inference API if env variable is set - } - } - - /** @param {string | object} prompt Either a String (ollama) or an OpenAI chat completion body object (openaicompatible): https://platform.openai.com/docs/api-reference/chat/create */ - async run(prompt, opts = {}) { - if (this.is_ext_inference_api) { - const stream = opts.stream ?? false; - - // default timeout 60s - const timeout = typeof opts.timeout === "number" ? opts.timeout : 60; - const timeoutMs = timeout * 1000; - - /** @type {'ollama' | 'openaicompatible'} */ - const mode = opts.mode ?? "ollama"; - - switch (mode) { - case "ollama": - case "openaicompatible": - break; - - default: - throw new TypeError(`invalid mode: ${mode}`); - } - - const timeoutSignal = AbortSignal.timeout(timeoutMs); - const signals = [opts.signal, timeoutSignal] - .filter((it) => it instanceof AbortSignal); - - const signal = AbortSignal.any(signals); - - const path = mode === "ollama" ? "/api/generate" : "/v1/chat/completions"; - const body = mode === "ollama" ? { prompt } : prompt; - - const res = await fetch( - new URL(path, this.inferenceAPIHost), - { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model: this.model, - stream, - ...body, - }), - }, - { signal }, - ); - - if (!res.ok) { - throw new Error( - `Failed to fetch inference API host. Status ${res.status}: ${res.statusText}`, - ); - } - - if (!res.body) { - throw new Error("Missing body"); - } - - const parseGenFn = mode === "ollama" - ? parseJSON - : stream === true - ? parseJSONOverEventStream - : parseJSON; - const itr = parseGenFn(res.body, signal); - - if (stream) { - return (async function* () { - for await (const message of itr) { - if ("error" in message) { - if (message.error instanceof Error) { - throw message.error; - } else { - throw new Error(message.error); - } - } - - yield message; - - switch (mode) { - case "ollama": { - if (message.done) { - return; - } - - break; - } - - case "openaicompatible": { - const finishReason = message.choices[0].finish_reason; - - if (finishReason) { - if (finishReason !== "stop") { - throw new Error("Expected a completed response."); - } - - return; - } - - break; - } - - default: - throw new Error("unreachable"); - } - } - - throw new Error( - "Did not receive done or success response in stream.", - ); - })(); - } else { - const message = await itr.next(); - - if (message.value && "error" in message.value) { - const error = message.value.error; - - if (error instanceof Error) { - throw error; - } else { - throw new Error(error); - } - } - - const finish = mode === "ollama" - ? 
message.value.done - : message.value.choices[0].finish_reason === "stop"; - - if (finish !== true) { - throw new Error("Expected a completed response."); - } - - return message.value; - } - } - - if (this.init) { - await this.init; - } - - const mean_pool = opts.mean_pool ?? true; - const normalize = opts.normalize ?? true; - const result = await core.ops.op_ai_run_model( - this.model, - prompt, - mean_pool, - normalize, - ); - - return result; - } -} - -const MAIN_WORKER_API = { - tryCleanupUnusedSession: () => - /* async */ core.ops.op_ai_try_cleanup_unused_session(), -}; - -const USER_WORKER_API = { - Session, -}; - -export { MAIN_WORKER_API, USER_WORKER_API }; diff --git a/ext/ai/js/ai.ts b/ext/ai/js/ai.ts new file mode 100644 index 00000000..23ad212d --- /dev/null +++ b/ext/ai/js/ai.ts @@ -0,0 +1,182 @@ +import "./onnxruntime/onnx.js"; +import { + LLMProviderInstance, + LLMProviderName, + LLMSession, + LLMSessionRunInputOptions as LLMInputOptions, + providers, +} from "./llm/llm_session.ts"; + +// @ts-ignore deno_core environment +const core = globalThis.Deno.core; + +// TODO: extract to utils file +export type Result = [T, undefined] | [undefined, E]; + +// NOTE:(kallebysantos) do we still need gte-small? Or maybe add another type 'embeddings' with custom model opt. +export type SessionType = LLMProviderName | "gte-small"; + +export type SessionOptions = T extends LLMProviderName + ? LLMProviderInstance["options"] + : never; + +export type SessionInput = T extends LLMProviderName + ? LLMProviderInstance["input"] + : T extends "gte-small" ? string + : never; + +export type EmbeddingInputOptions = { + /** + * Pool embeddings by taking their mean + */ + mean_pool?: boolean; + + /** + * Normalize the embeddings result + */ + normalize?: boolean; +}; + +export type SessionInputOptions = T extends "gte-small" + ? EmbeddingInputOptions + : T extends LLMProviderName ? LLMInputOptions + : never; + +export type SessionOutput = T extends "gte-small" + ? number[] + : T extends LLMProviderName ? O extends { stream: true } ? AsyncGenerator< + Result< + LLMProviderInstance["output"], + LLMProviderInstance["error"] + > + > + : LLMProviderInstance["output"] + : never; + +export type SessionError = { + message: string; + inner: T; +}; + +export type SessionOutputError = T extends "gte-small" + ? SessionError + : T extends LLMProviderName ? SessionError["error"]> + : any; + +export class Session { + #model?: string; + #init?: Promise; + + constructor( + public readonly type: T, + public readonly options?: SessionOptions, + ) { + if (this.isEmbeddingType()) { + this.#model = "gte-small"; // Default model + this.#init = core.ops.op_ai_init_model(this.#model); + return; + } + + if (this.isLLMType()) { + if (!Object.keys(providers).includes(type)) { + throw new TypeError(`invalid type: '${type}'`); + } + + if (!this.options || !this.options.model) { + throw new Error( + `missing required parameter 'model' for type: '${type}'`, + ); + } + + this.options.baseURL ??= core.ops.op_get_env( + "AI_INFERENCE_API_HOST", + ) as string; + + if (!this.options.baseURL) { + throw new Error( + `missing required parameter 'baseURL' for type: '${type}'`, + ); + } + } + } + + async run>( + input: SessionInput, + options?: O, + ): Promise< + [SessionOutput, undefined] | [undefined, SessionOutputError] + > { + try { + if (this.isLLMType()) { + const opts = options as LLMInputOptions; + const stream = opts.stream ?? 
false; + + const llmSession = LLMSession.fromProvider(this.type, { + // safety: We did check `options` during construction + baseURL: this.options!.baseURL, + model: this.options!.model, + ...this.options, // allows custom provider initialization like 'apiKey' + }); + + const [output, error] = await llmSession.run(input, { + stream, + signal: opts.signal, + timeout: opts.timeout, + }); + if (error) { + return [undefined, error as SessionOutputError]; + } + + return [output as SessionOutput, undefined]; + } + + if (this.#init) { + await this.#init; + } + + const opts = options as EmbeddingInputOptions | undefined; + + const mean_pool = opts?.mean_pool ?? true; + const normalize = opts?.normalize ?? true; + + const result = await core.ops.op_ai_run_model( + this.#model, + input, + mean_pool, + normalize, + ) as SessionOutput; + + return [result, undefined]; + } catch (e: any) { + const error = (e instanceof Error) ? e : new Error(e); + + return [ + undefined, + { inner: error, message: error.message } as SessionOutputError, + ]; + } + } + + private isEmbeddingType( + this: Session, + ): this is Session<"gte-small"> { + return this.type === "gte-small"; + } + + private isLLMType( + this: Session, + ): this is Session { + return this.type !== "gte-small"; + } +} + +const MAIN_WORKER_API = { + tryCleanupUnusedSession: () => + /* async */ core.ops.op_ai_try_cleanup_unused_session(), +}; + +const USER_WORKER_API = { + Session, +}; + +export { MAIN_WORKER_API, USER_WORKER_API }; diff --git a/ext/ai/js/llm/llm_session.ts b/ext/ai/js/llm/llm_session.ts new file mode 100644 index 00000000..b9db4e15 --- /dev/null +++ b/ext/ai/js/llm/llm_session.ts @@ -0,0 +1,142 @@ +import { Result, SessionError } from "../ai.ts"; +import { OllamaLLMSession } from "./providers/ollama.ts"; +import { OpenAILLMSession } from "./providers/openai.ts"; + +export type LLMRunInput = { + /** + * Stream response from model. Applies only for LLMs like `mistral` (default: false) + */ + stream?: boolean; + + /** + * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) + */ + timeout?: number; + + prompt: string; + + signal?: AbortSignal; +}; + +export interface ILLMProviderMeta { + input: ILLMProviderInput; + output: unknown; + error: unknown; + options: ILLMProviderOptions; +} + +export interface ILLMProviderOptions { + model: string; + baseURL?: string; +} + +export type ILLMProviderInput = T extends string ? string + : T; + +export interface ILLMProviderOutput { + value?: string; + usage: { + inputTokens: number; + outputTokens: number; + totalTokens: number; + }; + inner: T; +} + +export interface ILLMProviderError extends SessionError { +} + +export interface ILLMProvider { + getStream( + input: ILLMProviderInput, + signal: AbortSignal, + ): Promise< + Result< + AsyncIterable>, + ILLMProviderError + > + >; + getText( + input: ILLMProviderInput, + signal: AbortSignal, + ): Promise>; +} + +export const providers = { + "ollama": OllamaLLMSession, + "openaicompatible": OpenAILLMSession, +} satisfies Record< + string, + new (opts: ILLMProviderOptions) => ILLMProvider & ILLMProviderMeta +>; + +export type LLMProviderName = keyof typeof providers; + +export type LLMProviderClass = (typeof providers)[T]; +export type LLMProviderInstance = InstanceType< + LLMProviderClass +>; + +export type LLMSessionRunInputOptions = { + /** + * Stream response from model. 
Applies only for LLMs like `mistral` (default: false) + */ + stream?: boolean; + + /** + * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) + */ + timeout?: number; + + signal?: AbortSignal; +}; + +export type LLMSessionOutput = + | AsyncIterable> + | ILLMProviderOutput; + +export class LLMSession { + #inner: ILLMProvider; + + constructor(provider: ILLMProvider) { + this.#inner = provider; + } + + static fromProvider(name: LLMProviderName, opts: ILLMProviderOptions) { + const ProviderType = providers[name]; + if (!ProviderType) throw new Error("invalid provider"); + + const provider = new ProviderType(opts); + + return new LLMSession(provider); + } + + async run( + input: ILLMProviderInput, + opts: LLMSessionRunInputOptions, + ): Promise> { + const isStream = opts.stream ?? false; + + const timeoutSeconds = typeof opts.timeout === "number" ? opts.timeout : 60; + const timeoutMs = timeoutSeconds * 1000; + + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const abortSignals = [opts.signal, timeoutSignal] + .filter((it) => it instanceof AbortSignal); + const signal = AbortSignal.any(abortSignals); + + if (isStream) { + const [stream, getStreamError] = await this.#inner.getStream( + input, + signal, + ); + if (getStreamError) { + return [undefined, getStreamError]; + } + + return [stream, undefined]; + } + + return this.#inner.getText(input, signal); + } +} diff --git a/ext/ai/js/llm/providers/ollama.ts b/ext/ai/js/llm/providers/ollama.ts new file mode 100644 index 00000000..0956f1cd --- /dev/null +++ b/ext/ai/js/llm/providers/ollama.ts @@ -0,0 +1,207 @@ +import { Result } from "../../ai.ts"; +import { + ILLMProvider, + ILLMProviderError, + ILLMProviderInput, + ILLMProviderMeta, + ILLMProviderOptions, + ILLMProviderOutput, +} from "../llm_session.ts"; +import { parseJSON } from "../utils/json_parser.ts"; + +export type OllamaProviderOptions = ILLMProviderOptions; +export type OllamaProviderInput = ILLMProviderInput; +export type OllamaProviderOutput = Result< + ILLMProviderOutput, + OllamaProviderError +>; +export type OllamaProviderError = ILLMProviderError; + +export type OllamaMessage = { + model: string; + created_at: Date; + response: string; + done: boolean; + context: number[]; + total_duration: number; + load_duration: number; + prompt_eval_count: number; + prompt_eval_duration: number; + eval_count: number; + eval_duration: number; +}; + +export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { + input!: OllamaProviderInput; + output!: ILLMProviderOutput; + error!: OllamaProviderError; + options: OllamaProviderOptions; + + constructor(opts: OllamaProviderOptions) { + this.options = opts; + } + + // ref: https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L26 + async getStream( + prompt: OllamaProviderInput, + signal: AbortSignal, + ): Promise< + Result, OllamaProviderError> + > { + const [generator, error] = await this.generate( + prompt, + signal, + true, + ) as Result, OllamaProviderError>; + + if (error) { + return [undefined, error]; + } + + // NOTE:(kallebysantos) we need to clone the lambda parser to avoid `this` conflicts inside the local function* + const parser = this.parse; + const stream = async function* () { + for await (const message of generator) { + if ("error" in message) { + const error = (message.error instanceof Error) + ? 
message.error + : new Error(message.error as string); + + yield [ + undefined, + { + inner: { + error, + currentValue: null, + }, + message: "An unknown error was streamed from the provider.", + } satisfies OllamaProviderError, + ]; + } + + yield [parser(message), undefined]; + + if (message.done) { + return; + } + } + + throw new Error( + "Did not receive done or success response in stream.", + ); + }; + + return [ + stream() as AsyncIterable, + undefined, + ]; + } + + async getText( + prompt: OllamaProviderInput, + signal: AbortSignal, + ): Promise { + const [generation, generationError] = await this.generate( + prompt, + signal, + ) as Result; + + if (generationError) { + return [undefined, generationError]; + } + + if (!generation?.done) { + return [undefined, { + inner: { + error: new Error("Expected a completed response."), + currentValue: generation, + }, + message: + `Response could not be completed successfully. Expected 'done'`, + }]; + } + + return [this.parse(generation), undefined]; + } + + private parse(message: OllamaMessage): ILLMProviderOutput { + const { response, prompt_eval_count, eval_count } = message; + + const inputTokens = isNaN(prompt_eval_count) ? 0 : prompt_eval_count; + const outputTokens = isNaN(eval_count) ? 0 : eval_count; + + return { + value: response, + inner: message, + usage: { + inputTokens, + outputTokens, + totalTokens: inputTokens + outputTokens, + }, + }; + } + + private async generate( + prompt: string, + signal: AbortSignal, + stream: boolean = false, + ): Promise< + Result | OllamaMessage, OllamaProviderError> + > { + const res = await fetch( + new URL("/api/generate", this.options.baseURL), + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: this.options.model, + stream, + prompt, + }), + signal, + }, + ); + + // try to extract the json error otherwise return any text content from the response + if (!res.ok || !res.body) { + const errorMsg = + `Failed to fetch inference API host '${this.options.baseURL}'. 
Status ${res.status}: ${res.statusText}`; + + if (!res.body) { + const error = { + inner: new Error("Missing response body."), + message: errorMsg, + } satisfies OllamaProviderError; + + return [undefined, error]; + } + + // safe to extract response body cause it was checked above + try { + const error = { + inner: await res.json(), + message: errorMsg, + } satisfies OllamaProviderError; + + return [undefined, error]; + } catch (_) { + const error = { + inner: new Error(await res.text()), + message: errorMsg, + } satisfies OllamaProviderError; + + return [undefined, error]; + } + } + + if (stream) { + const stream = parseJSON(res.body, signal); + return [stream as AsyncGenerator, undefined]; + } + + const result: OllamaMessage = await res.json(); + return [result, undefined]; + } +} diff --git a/ext/ai/js/llm/providers/openai.ts b/ext/ai/js/llm/providers/openai.ts new file mode 100644 index 00000000..e085f934 --- /dev/null +++ b/ext/ai/js/llm/providers/openai.ts @@ -0,0 +1,315 @@ +import { Result } from "../../ai.ts"; +import { + ILLMProvider, + ILLMProviderError, + ILLMProviderInput, + ILLMProviderMeta, + ILLMProviderOptions, + ILLMProviderOutput, +} from "../llm_session.ts"; +import { parseJSONOverEventStream } from "../utils/json_parser.ts"; + +export type OpenAIProviderOptions = ILLMProviderOptions & { + apiKey?: string; +}; + +// NOTE:(kallebysantos) we define all types here for better development as well avoid `"npm:openai"` import +// TODO:(kallebysantos) need to double check theses AI generated types +export type OpenAIRequest = { + model: string; + messages: { + // NOTE:(kallebysantos) using role as union type is great for intellisense suggestions + // but at same time it forces users to `{} satisfies Supabase.ai.OpenAICompatibleInput` + role: "system" | "user" | "assistant" | "tool"; + content: string; + name?: string; + tool_call_id?: string; + function_call?: { + name: string; + arguments: string; + }; + }[]; + temperature?: number; + top_p?: number; + n?: number; + stream?: boolean; + stream_options: { + include_usage: boolean; + }; + stop?: string | string[]; + max_tokens?: number; + presence_penalty?: number; + frequency_penalty?: number; + logit_bias?: { [token: string]: number }; + user?: string; + tools?: { + type: "function"; + function: { + name: string; + description?: string; + parameters: unknown; + }; + }[]; + tool_choice?: "none" | "auto" | { + type: "function"; + function: { name: string }; + }; +}; + +export type OpenAIResponseUsage = { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + prompt_tokens_details: { + cached_tokens: number; + audio_tokens: number; + }; + completion_tokens_details: { + reasoning_tokens: number; + audio_tokens: number; + accepted_prediction_tokens: number; + rejected_prediction_tokens: number; + }; +}; + +export type OpenAIResponseChoice = { + index: number; + message?: { + role: "assistant" | "user" | "system" | "tool"; + content: string | null; + function_call?: { + name: string; + arguments: string; + }; + tool_calls?: { + id: string; + type: "function"; + function: { + name: string; + arguments: string; + }; + }[]; + }; + delta?: { + content: string | null; + }; + finish_reason: "stop" | "length" | "tool_calls" | "content_filter" | null; +}; + +export type OpenAIResponse = { + id: string; + object: "chat.completion"; + created: number; + model: string; + system_fingerprint?: string; + choices: OpenAIResponseChoice[]; + usage?: OpenAIResponseUsage; +}; + +export type OpenAICompatibleInput = Omit< + 
OpenAIRequest, + "stream" | "stream_options" | "model" +>; + +export type OpenAIProviderInput = ILLMProviderInput; +export type OpenAIProviderOutput = Result< + ILLMProviderOutput, + OpenAIProviderError +>; +export type OpenAIProviderError = ILLMProviderError; + +export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { + input!: OpenAIProviderInput; + output!: ILLMProviderOutput; + error!: OpenAIProviderError; + options: OpenAIProviderOptions; + + constructor(opts: OpenAIProviderOptions) { + this.options = opts; + } + + async getStream( + prompt: OpenAIProviderInput, + signal: AbortSignal, + ): Promise< + Result, OpenAIProviderError> + > { + const [generator, error] = await this.generate( + prompt, + signal, + true, + ) as Result, OpenAIProviderError>; + + if (error) { + return [undefined, error]; + } + + // NOTE:(kallebysantos) we need to clone the lambda parser to avoid `this` conflicts inside the local function* + const parser = this.parse; + const stream = async function* () { + for await (const message of generator) { + // NOTE:(kallebysantos) while streaming the final message will not include 'finish_reason' + // Instead a '[DONE]' value will be returned to close the stream + if ("done" in message && message.done) { + return; + } + + if ("error" in message) { + const error = (message.error instanceof Error) + ? message.error + : new Error(message.error as string); + + yield [ + undefined, + { + inner: { + error, + currentValue: null, + }, + message: "An unknown error was streamed from the provider.", + } satisfies OpenAIProviderError, + ]; + } + + yield [parser(message), undefined]; + + const finishReason = message.choices.at(0)?.finish_reason; + if (finishReason && finishReason !== "stop") { + yield [undefined, { + inner: { + error: new Error("Expected a completed response."), + currentValue: message, + }, + message: + `Response could not be completed successfully. Expected 'stop' finish reason got '${finishReason}'`, + }]; + } + } + + throw new Error( + "Did not receive done or success response in stream.", + ); + }; + + return [ + stream() as AsyncIterable, + undefined, + ]; + } + + async getText( + prompt: OpenAIProviderInput, + signal: AbortSignal, + ): Promise { + const [generation, generationError] = await this.generate( + prompt, + signal, + ) as Result; + + if (generationError) { + return [undefined, generationError]; + } + + const finishReason = generation.choices[0].finish_reason; + + if (finishReason !== "stop") { + return [undefined, { + inner: { + error: new Error("Expected a completed response."), + currentValue: generation, + }, + message: + `Response could not be completed successfully. Expected 'stop' finish reason got '${finishReason}'`, + }]; + } + + return [this.parse(generation), undefined]; + } + + private parse(response: OpenAIResponse): ILLMProviderOutput { + const { usage } = response; + const choice = response.choices.at(0); + + return { + // NOTE:(kallebysantos) while streaming the 'delta' field will be used instead of 'message' + value: choice?.message?.content ?? choice?.delta?.content ?? undefined, + inner: response, + usage: { + // NOTE:(kallebysantos) usage maybe 'null' while streaming, but the final message will include it + inputTokens: usage?.prompt_tokens ?? 0, + outputTokens: usage?.completion_tokens ?? 0, + totalTokens: usage?.total_tokens ?? 
0, + }, + }; + } + + private async generate( + input: OpenAICompatibleInput, + signal: AbortSignal, + stream: boolean = false, + ): Promise< + Result | OpenAIResponse, OpenAIProviderError> + > { + const res = await fetch( + new URL("/v1/chat/completions", this.options.baseURL), + { + method: "POST", + headers: { + "Content-Type": "application/json", + "Authorization": `Bearer ${this.options.apiKey}`, + }, + body: JSON.stringify( + { + ...input, + model: this.options.model, + stream, + stream_options: { + include_usage: true, + }, + } satisfies OpenAIRequest, + ), + signal, + }, + ); + + // try to extract the json error otherwise return any text content from the response + if (!res.ok || !res.body) { + const errorMsg = + `Failed to fetch inference API host '${this.options.baseURL}'. Status ${res.status}: ${res.statusText}`; + + if (!res.body) { + const error = { + inner: new Error("Missing response body."), + message: errorMsg, + } satisfies OpenAIProviderError; + + return [undefined, error]; + } + + // safe to extract response body cause it was checked above + try { + const error = { + inner: await res.json(), + message: errorMsg, + } satisfies OpenAIProviderError; + + return [undefined, error]; + } catch (_) { + const error = { + inner: new Error(await res.text()), + message: errorMsg, + } satisfies OpenAIProviderError; + + return [undefined, error]; + } + } + + if (stream) { + const stream = parseJSONOverEventStream(res.body, signal); + return [stream as AsyncGenerator, undefined]; + } + + const result: OpenAIResponse = await res.json(); + return [result, undefined]; + } +} diff --git a/ext/ai/js/llm/utils/event_source_stream.mjs b/ext/ai/js/llm/utils/event_source_stream.mjs new file mode 100644 index 00000000..0ec0b889 --- /dev/null +++ b/ext/ai/js/llm/utils/event_source_stream.mjs @@ -0,0 +1,41 @@ +import EventStreamParser from "./event_stream_parser.mjs"; +/** + * A Web stream which handles Server-Sent Events from a binary ReadableStream like you get from the fetch API. + * Implements the TransformStream interface, and can be used with the Streams API as such. + */ +class EventSourceStream { + constructor() { + // Two important things to note here: + // 1. The SSE spec allows for an optional UTF-8 BOM. + // 2. We have to use a *streaming* decoder, in case two adjacent data chunks are split up in the middle of a + // multibyte Unicode character. Trying to parse the two separately would result in data corruption. 
+ const decoder = new TextDecoderStream("utf-8"); + let parser; + const sseStream = new TransformStream({ + start(controller) { + parser = new EventStreamParser((data, eventType, lastEventId) => { + // NOTE:(kallebysantos) Some providers like OpenAI send '[DONE]' + // to indicates stream terminates, so we need to check if the SSE contains "[DONE]" and close the stream + if (typeof data === "string" && data.trim() === "[DONE]") { + controller.terminate?.(); // If supported + controller.close?.(); // Fallback + return; + } + + controller.enqueue( + new MessageEvent(eventType, { data, lastEventId }), + ); + }); + }, + transform(chunk) { + parser.push(chunk); + }, + }); + + decoder.readable.pipeThrough(sseStream); + + this.readable = sseStream.readable; + this.writable = decoder.writable; + } +} +export default EventSourceStream; diff --git a/ext/ai/js/llm/utils/event_stream_parser.mjs b/ext/ai/js/llm/utils/event_stream_parser.mjs new file mode 100644 index 00000000..263229a6 --- /dev/null +++ b/ext/ai/js/llm/utils/event_stream_parser.mjs @@ -0,0 +1,92 @@ +// https://github.com/valadaptive/server-sent-stream + +/** + * A parser for the server-sent events stream format. + * + * Note that this parser does not handle text decoding! To do it correctly, use a streaming text decoder, since the + * stream could be split up mid-Unicode character, and decoding each chunk at once could lead to incorrect results. + * + * This parser is used by streaming chunks in using the {@link push} method, and then calling the {@link end} method + * when the stream has ended. + */ +class EventStreamParser { + /** + * Construct a new parser for a single stream. + * @param onEvent A callback which will be called for each new event parsed. The parameters in order are the + * event data, the event type, and the last seen event ID. This may be called none, once, or many times per push() + * call, and may be called from the end() call. + */ + constructor(onEvent) { + this.streamBuffer = ""; + this.lastEventId = ""; + this.onEvent = onEvent; + } + /** + * Process a single incoming chunk of the event stream. + */ + _processChunk() { + // Events are separated by two newlines + const events = this.streamBuffer.split(/\r\n\r\n|\r\r|\n\n/g); + if (events.length === 0) { + return; + } + // The leftover text to remain in the buffer is whatever doesn't have two newlines after it. If the buffer ended + // with two newlines, this will be an empty string. + this.streamBuffer = events.pop(); + for (const eventChunk of events) { + let eventType = ""; + // Split up by single newlines. + const lines = eventChunk.split(/\n|\r|\r\n/g); + let eventData = ""; + for (const line of lines) { + const lineMatch = /([^:]+)(?:: ?(.*))?/.exec(line); + if (lineMatch) { + const field = lineMatch[1]; + const value = lineMatch[2] || ""; + switch (field) { + case "event": + eventType = value; + break; + case "data": + eventData += value; + eventData += "\n"; + break; + case "id": + // The ID field cannot contain null, per the spec + if (!value.includes("\0")) { + this.lastEventId = value; + } + break; + // We do nothing for the `delay` type, and other types are explicitly ignored + } + } + } + // https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage + // Skip the event if the data buffer is the empty string. + if (eventData === "") { + continue; + } + if (eventData[eventData.length - 1] === "\n") { + eventData = eventData.slice(0, -1); + } + // Trim the *last* trailing newline only. 
+ this.onEvent(eventData, eventType || "message", this.lastEventId); + } + } + /** + * Push a new chunk of data to the parser. This may cause the {@link onEvent} callback to be called, possibly + * multiple times depending on the number of events contained within the chunk. + * @param chunk The incoming chunk of data. + */ + push(chunk) { + this.streamBuffer += chunk; + this._processChunk(); + } + /** + * Indicate that the stream has ended. + */ + end() { + // This is a no-op + } +} +export default EventStreamParser; diff --git a/ext/ai/js/llm/utils/json_parser.ts b/ext/ai/js/llm/utils/json_parser.ts new file mode 100644 index 00000000..2cb31866 --- /dev/null +++ b/ext/ai/js/llm/utils/json_parser.ts @@ -0,0 +1,83 @@ +import EventSourceStream from "./event_source_stream.mjs"; + +// Adapted from https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L262 +// TODO:(kallebysantos) need to simplify it +export async function* parseJSON( + itr: ReadableStream, + signal: AbortSignal, +) { + let buffer = ""; + + const decoder = new TextDecoder("utf-8"); + const reader = itr.getReader(); + + while (true) { + try { + if (signal.aborted) { + reader.cancel(signal.reason); + reader.releaseLock(); + return { error: signal.reason }; + } + + const { done, value } = await reader.read(); + + if (done) { + break; + } + + buffer += decoder.decode(value); + + const parts = buffer.split("\n"); + + buffer = parts.pop() ?? ""; + + for (const part of parts) { + yield JSON.parse(part) as T; + } + } catch (error) { + yield { error }; + } + } + + for (const part of buffer.split("\n").filter((p) => p !== "")) { + try { + yield JSON.parse(part) as T; + } catch (error) { + yield { error }; + } + } +} + +// TODO:(kallebysantos) need to simplify it +export async function* parseJSONOverEventStream( + itr: ReadableStream, + signal: AbortSignal, +) { + const decoder = new EventSourceStream(); + + itr.pipeThrough(decoder); + + const reader: ReadableStreamDefaultReader = decoder.readable + .getReader(); + + while (true) { + try { + if (signal.aborted) { + reader.cancel(signal.reason); + reader.releaseLock(); + return { error: signal.reason }; + } + + const { done, value } = await reader.read(); + + if (done) { + yield { done }; + break; + } + + yield JSON.parse(value.data) as T; + } catch (error) { + yield { error }; + } + } +} diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 224b0450..b58280a7 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -48,14 +48,18 @@ deno_core::extension!( op_ai_ort_init_session, op_ai_ort_run_session, ], - esm_entry_point = "ext:ai/ai.js", + esm_entry_point = "ext:ai/ai.ts", esm = [ dir "js", - "ai.js", - "util/event_stream_parser.mjs", - "util/event_source_stream.mjs", + "ai.ts", "onnxruntime/onnx.js", - "onnxruntime/cache_adapter.js" + "onnxruntime/cache_adapter.js", + "llm/llm_session.ts", + "llm/providers/ollama.ts", + "llm/providers/openai.ts", + "llm/utils/json_parser.ts", + "llm/utils/event_stream_parser.mjs", + "llm/utils/event_source_stream.mjs", ] ); @@ -276,6 +280,10 @@ async fn run_gte( mean_pool: bool, normalize: bool, ) -> Result, Error> { + if prompt.is_empty() { + bail!("must provide a valid prompt value, got 'empty'") + } + let req_tx; { let op_state = state.borrow(); diff --git a/ext/runtime/js/namespaces.js b/ext/runtime/js/namespaces.js index 367d8a02..d9112909 100644 --- a/ext/runtime/js/namespaces.js +++ b/ext/runtime/js/namespaces.js @@ -1,6 +1,6 @@ import { core, primordials } from "ext:core/mod.js"; -import { MAIN_WORKER_API, 
USER_WORKER_API } from "ext:ai/ai.js"; +import { MAIN_WORKER_API, USER_WORKER_API } from "ext:ai/ai.ts"; import { SUPABASE_USER_WORKERS } from "ext:user_workers/user_workers.js"; import { applySupabaseTag } from "ext:runtime/http.js"; import { waitUntil } from "ext:runtime/async_hook.js"; diff --git a/types/global.d.ts b/types/global.d.ts index 260f3b99..d28fea24 100644 --- a/types/global.d.ts +++ b/types/global.d.ts @@ -123,58 +123,65 @@ declare namespace EdgeRuntime { export { UserWorker as userWorkers }; } +// TODO:(kallebysantos) use some TS builder to bundle all types +import { ai as AINamespace } from "../ext/ai/js/ai.d.ts"; + declare namespace Supabase { - export namespace ai { - interface ModelOptions { - /** - * Pool embeddings by taking their mean. Applies only for `gte-small` model - */ - mean_pool?: boolean; - - /** - * Normalize the embeddings result. Applies only for `gte-small` model - */ - normalize?: boolean; - - /** - * Stream response from model. Applies only for LLMs like `mistral` (default: false) - */ - stream?: boolean; - - /** - * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) - */ - timeout?: number; - - /** - * Mode for the inference API host. (default: 'ollama') - */ - mode?: "ollama" | "openaicompatible"; - signal?: AbortSignal; - } - - export class Session { - /** - * Create a new model session using given model - */ - constructor(model: string); - - /** - * Execute the given prompt in model session - */ - run( - prompt: - | string - | Omit< - import("openai").OpenAI.Chat.ChatCompletionCreateParams, - "model" | "stream" - >, - modelOptions?: ModelOptions, - ): unknown; - } - } + export import ai = AINamespace; } +// declare namespace Supabase { +// export namespace ai { +// interface ModelOptions { +// /** +// * Pool embeddings by taking their mean. Applies only for `gte-small` model +// */ +// mean_pool?: boolean; +// +// /** +// * Normalize the embeddings result. Applies only for `gte-small` model +// */ +// normalize?: boolean; +// +// /** +// * Stream response from model. Applies only for LLMs like `mistral` (default: false) +// */ +// stream?: boolean; +// +// /** +// * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) +// */ +// timeout?: number; +// +// /** +// * Mode for the inference API host. (default: 'ollama') +// */ +// mode?: "ollama" | "openaicompatible"; +// signal?: AbortSignal; +// } +// +// export class Session { +// /** +// * Create a new model session using given model +// */ +// constructor(model: string); +// +// /** +// * Execute the given prompt in model session +// */ +// run( +// prompt: +// | string +// | Omit< +// import("openai").OpenAI.Chat.ChatCompletionCreateParams, +// "model" | "stream" +// >, +// modelOptions?: ModelOptions, +// ): unknown; +// } +// } +// } + declare namespace Deno { export namespace errors { class WorkerRequestCancelled extends Error {}
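
For reviewers, a minimal usage sketch of the reworked `Supabase.ai.Session` API, assuming the types in `ext/ai/js/ai.ts` land as written; it mirrors the `[value, error]` tuple pattern exercised in `crates/base/test_cases/supabase-ai/index.ts`. The model name `mistral` and the inference host URL are illustrative placeholders, not values taken from this PR.

```ts
// Embeddings: the "gte-small" session now resolves to a [value, error]
// tuple (the Result<T, E> type added in ai.ts) instead of throwing.
const embedder = new Supabase.ai.Session("gte-small");

const [embedding, embeddingError] = await embedder.run("hello world", {
  mean_pool: true,
  normalize: true,
});

if (embeddingError) {
  console.error(embeddingError.message);
} else {
  console.log(embedding.length); // number[] of embedding dimensions
}

// LLM text generation via the Ollama provider. "mistral" and baseURL are
// placeholders; baseURL falls back to AI_INFERENCE_API_HOST when omitted.
const llm = new Supabase.ai.Session("ollama", {
  model: "mistral",
  baseURL: "http://localhost:11434",
});

const [reply, replyError] = await llm.run("Say hello", { timeout: 30 });

if (replyError) {
  console.error(replyError.message, replyError.inner);
} else {
  console.log(reply.value, reply.usage.totalTokens);
}
```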
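And a streaming sketch against the OpenAI-compatible provider: with `stream: true`, `run` resolves to an async iterable of `[chunk, error]` pairs rather than a single value, per `SessionOutput` and the provider's `getStream`. The endpoint, API key variable, and model name below are placeholders.

```ts
const openai = new Supabase.ai.Session("openaicompatible", {
  model: "gpt-4o-mini", // placeholder model name
  baseURL: "https://api.openai.com", // placeholder; any OpenAI-compatible host
  apiKey: Deno.env.get("OPENAI_API_KEY"), // placeholder env var
});

const [stream, streamError] = await openai.run(
  { messages: [{ role: "user", content: "Write a haiku about cats" }] },
  { stream: true, timeout: 60 },
);

if (streamError) {
  console.error(streamError.message);
} else {
  for await (const [chunk, chunkError] of stream) {
    if (chunkError) {
      console.error(chunkError.message);
      break;
    }
    // While streaming, `value` carries the delta content; usage totals only
    // arrive with the final message, since include_usage is set upstream.
    console.log(chunk.value);
  }
}
```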