diff --git a/.github/matrix-commitly.yml b/.github/matrix-commitly.yml index 7685340597c..5e52cbc80f7 100644 --- a/.github/matrix-commitly.yml +++ b/.github/matrix-commitly.yml @@ -1,7 +1,7 @@ # please see matrix-full.yml for meaning of each field build-packages: - label: ubuntu-22.04 - os: ubuntu-22.04 + image: ubuntu:22.04 package: deb check-manifest-suite: ubuntu-22.04-amd64 diff --git a/.github/matrix-full.yml b/.github/matrix-full.yml index b011607f4c8..376fcac72ef 100644 --- a/.github/matrix-full.yml +++ b/.github/matrix-full.yml @@ -12,9 +12,11 @@ build-packages: package: deb check-manifest-suite: ubuntu-20.04-amd64 - label: ubuntu-22.04 + image: ubuntu:22.04 package: deb check-manifest-suite: ubuntu-22.04-amd64 - label: ubuntu-22.04-arm64 + image: ubuntu:22.04 package: deb bazel-args: --platforms=//:generic-crossbuild-aarch64 check-manifest-suite: ubuntu-22.04-arm64 diff --git a/changelog/unreleased/fix_hash.yml b/changelog/unreleased/fix_hash.yml new file mode 100644 index 00000000000..6c97221121d --- /dev/null +++ b/changelog/unreleased/fix_hash.yml @@ -0,0 +1,3 @@ +message: Fixed an inefficiency issue in the Luajit hashing algorithm +type: performance +scope: Performance diff --git a/changelog/unreleased/kong/ai-proxy-azure-streaming.yml b/changelog/unreleased/kong/ai-proxy-azure-streaming.yml new file mode 100644 index 00000000000..4b6f7c55669 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-azure-streaming.yml @@ -0,0 +1,5 @@ +message: | + **AI-proxy-plugin**: Fixed a bug where certain Azure models would return partial tokens/words + when in response-streaming mode. +scope: Plugin +type: bugfix diff --git a/changelog/unreleased/kong/ai-proxy-fix-model-parameter.yml b/changelog/unreleased/kong/ai-proxy-fix-model-parameter.yml new file mode 100644 index 00000000000..3727a02c4c2 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-fix-model-parameter.yml @@ -0,0 +1,5 @@ +message: | + **AI-proxy-plugin**: Fixed a bug where Cohere and Anthropic providers don't read the `model` parameter properly + from the caller's request body. +scope: Plugin +type: bugfix diff --git a/changelog/unreleased/kong/ai-proxy-fix-nil-response-token-count.yml b/changelog/unreleased/kong/ai-proxy-fix-nil-response-token-count.yml new file mode 100644 index 00000000000..f6681f7ec8b --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-fix-nil-response-token-count.yml @@ -0,0 +1,5 @@ +message: | + **AI-proxy-plugin**: Fixed a bug where using "OpenAI Function" inference requests would log a + request error, and then hang until timeout. +scope: Plugin +type: bugfix diff --git a/changelog/unreleased/kong/ai-proxy-fix-sending-own-model.yml b/changelog/unreleased/kong/ai-proxy-fix-sending-own-model.yml new file mode 100644 index 00000000000..fe432c71db5 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-fix-sending-own-model.yml @@ -0,0 +1,5 @@ +message: | + **AI-proxy-plugin**: Fixed a bug where AI Proxy would still allow callers to specify their own model, + ignoring the plugin-configured model name. 
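The enforcement described in this changelog entry lands later in the diff, in the `access` phase of `kong/plugins/ai-proxy/handler.lua`. A condensed sketch of that guard, with `conf_m` (the working copy of the plugin configuration) and `bad_request` (returns a 400) taken from that hunk:

if request_table and type(request_table.model) == "string" and request_table.model ~= "" then
  if request_table.model ~= conf_m.model.name then
    return bad_request("cannot use own model - must be: " .. conf_m.model.name)
  end
end
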
+scope: Plugin +type: bugfix diff --git a/changelog/unreleased/kong/ai-proxy-fix-tuning-parameter-precedence.yml b/changelog/unreleased/kong/ai-proxy-fix-tuning-parameter-precedence.yml new file mode 100644 index 00000000000..9588b6d6f0e --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-fix-tuning-parameter-precedence.yml @@ -0,0 +1,5 @@ +message: | + **AI-proxy-plugin**: Fixed a bug where AI Proxy would not take precedence of the + plugin's configured model tuning options, over those in the user's LLM request. +scope: Plugin +type: bugfix diff --git a/changelog/unreleased/kong/ai-proxy-proper-model-assignment.yml b/changelog/unreleased/kong/ai-proxy-proper-model-assignment.yml new file mode 100644 index 00000000000..3f61e43f5b2 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-proper-model-assignment.yml @@ -0,0 +1,5 @@ +message: | + **AI-proxy-plugin**: Fixed a bug where setting OpenAI SDK model parameter "null" caused analytics + to not be written to the logging plugin(s). +scope: Plugin +type: bugfix diff --git a/changelog/unreleased/kong/fix-ai-proxy-shared-state.yml b/changelog/unreleased/kong/fix-ai-proxy-shared-state.yml new file mode 100644 index 00000000000..bb967a94656 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-shared-state.yml @@ -0,0 +1,3 @@ +message: "**AI-Proxy**: Resolved a bug where the object constructor would set data on the class instead of the instance" +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-realm-compat-changes-basic-auth.yml b/changelog/unreleased/kong/fix-realm-compat-changes-basic-auth.yml new file mode 100644 index 00000000000..6f2ce9d7bea --- /dev/null +++ b/changelog/unreleased/kong/fix-realm-compat-changes-basic-auth.yml @@ -0,0 +1,3 @@ +message: "**Basic-Auth**: Fix an issue of realm field not recognized for older kong versions (before 3.6)" +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-realm-compat-changes-key-auth.yml b/changelog/unreleased/kong/fix-realm-compat-changes-key-auth.yml new file mode 100644 index 00000000000..bb8d06a3146 --- /dev/null +++ b/changelog/unreleased/kong/fix-realm-compat-changes-key-auth.yml @@ -0,0 +1,3 @@ +message: "**Key-Auth**: Fix an issue of realm field not recognized for older kong versions (before 3.7)" +type: bugfix +scope: Plugin diff --git a/kong-3.7.1-0.rockspec b/kong-3.7.1-0.rockspec index c352825d5bd..6a35b313fc6 100644 --- a/kong-3.7.1-0.rockspec +++ b/kong-3.7.1-0.rockspec @@ -591,6 +591,7 @@ build = { ["kong.plugins.ai-response-transformer.schema"] = "kong/plugins/ai-response-transformer/schema.lua", ["kong.llm"] = "kong/llm/init.lua", + ["kong.llm.schemas"] = "kong/llm/schemas/init.lua", ["kong.llm.drivers.shared"] = "kong/llm/drivers/shared.lua", ["kong.llm.drivers.openai"] = "kong/llm/drivers/openai.lua", ["kong.llm.drivers.azure"] = "kong/llm/drivers/azure.lua", diff --git a/kong/clustering/compat/removed_fields.lua b/kong/clustering/compat/removed_fields.lua index 9893dd60cef..213c0e0e71e 100644 --- a/kong/clustering/compat/removed_fields.lua +++ b/kong/clustering/compat/removed_fields.lua @@ -115,6 +115,9 @@ return { opentelemetry = { "sampling_rate", }, + basic_auth = { + "realm" + } }, -- Any dataplane older than 3.7.0 @@ -135,5 +138,8 @@ return { ai_response_transformer = { "llm.model.options.upstream_path", }, + key_auth = { + "realm" + } }, } diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index a18774b331d..fcc6419d33b 100644 --- a/kong/llm/drivers/anthropic.lua +++ 
b/kong/llm/drivers/anthropic.lua @@ -93,8 +93,8 @@ local transformers_to = { return nil, nil, err end - messages.temperature = request_table.temperature or (model.options and model.options.temperature) or nil - messages.max_tokens = request_table.max_tokens or (model.options and model.options.max_tokens) or nil + messages.temperature = (model.options and model.options.temperature) or request_table.temperature + messages.max_tokens = (model.options and model.options.max_tokens) or request_table.max_tokens messages.model = model.name or request_table.model messages.stream = request_table.stream or false -- explicitly set this if nil @@ -110,9 +110,8 @@ local transformers_to = { return nil, nil, err end - prompt.temperature = request_table.temperature or (model.options and model.options.temperature) or nil - prompt.max_tokens_to_sample = request_table.max_tokens or (model.options and model.options.max_tokens) or nil - prompt.model = model.name + prompt.temperature = (model.options and model.options.temperature) or request_table.temperature + prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or request_table.max_tokens prompt.model = model.name or request_table.model prompt.stream = request_table.stream or false -- explicitly set this if nil @@ -152,11 +151,9 @@ local function start_to_event(event_data, model_info) local metadata = { prompt_tokens = meta.usage - and meta.usage.input_tokens - or nil, + and meta.usage.input_tokens, completion_tokens = meta.usage - and meta.usage.output_tokens - or nil, + and meta.usage.output_tokens, model = meta.model, stop_reason = meta.stop_reason, stop_sequence = meta.stop_sequence, @@ -209,14 +206,11 @@ local function handle_stream_event(event_t, model_info, route_type) and event_data.usage then return nil, nil, { prompt_tokens = nil, - completion_tokens = event_data.usage.output_tokens - or nil, + completion_tokens = event_data.usage.output_tokens, stop_reason = event_data.delta - and event_data.delta.stop_reason - or nil, + and event_data.delta.stop_reason, stop_sequence = event_data.delta - and event_data.delta.stop_sequence - or nil, + and event_data.delta.stop_sequence, } else return nil, "message_delta is missing the metadata block", nil @@ -267,7 +261,7 @@ local transformers_from = { prompt_tokens = usage.input_tokens, completion_tokens = usage.output_tokens, total_tokens = usage.input_tokens and usage.output_tokens and - usage.input_tokens + usage.output_tokens or nil, + usage.input_tokens + usage.output_tokens, } else @@ -442,12 +436,7 @@ function _M.post_request(conf) end function _M.pre_request(conf, body) - -- check for user trying to bring own model - if body and body.model then - return nil, "cannot use own model for this instance" - end - - return true, nil + return true end -- returns err or nil diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 79aa0ca5010..b96cbbbc2d4 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -219,18 +219,15 @@ local transformers_from = { local stats = { completion_tokens = response_table.meta and response_table.meta.billed_units - and response_table.meta.billed_units.output_tokens - or nil, + and response_table.meta.billed_units.output_tokens, prompt_tokens = response_table.meta and response_table.meta.billed_units - and response_table.meta.billed_units.input_tokens - or nil, + and response_table.meta.billed_units.input_tokens, total_tokens = response_table.meta and response_table.meta.billed_units - and 
(response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens) - or nil, + and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens), } messages.usage = stats @@ -252,18 +249,15 @@ local transformers_from = { local stats = { completion_tokens = response_table.meta and response_table.meta.billed_units - and response_table.meta.billed_units.output_tokens - or nil, + and response_table.meta.billed_units.output_tokens, prompt_tokens = response_table.meta and response_table.meta.billed_units - and response_table.meta.billed_units.input_tokens - or nil, + and response_table.meta.billed_units.input_tokens, total_tokens = response_table.meta and response_table.meta.billed_units - and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens) - or nil, + and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens), } messages.usage = stats @@ -271,7 +265,7 @@ local transformers_from = { return nil, "'text' or 'generations' missing from cohere response body" end - + return cjson.encode(messages) end, @@ -299,11 +293,10 @@ local transformers_from = { prompt.id = response_table.id local stats = { - completion_tokens = response_table.meta and response_table.meta.billed_units.output_tokens or nil, - prompt_tokens = response_table.meta and response_table.meta.billed_units.input_tokens or nil, + completion_tokens = response_table.meta and response_table.meta.billed_units.output_tokens, + prompt_tokens = response_table.meta and response_table.meta.billed_units.input_tokens, total_tokens = response_table.meta - and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens) - or nil, + and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens), } prompt.usage = stats @@ -323,9 +316,9 @@ local transformers_from = { prompt.id = response_table.generation_id local stats = { - completion_tokens = response_table.token_count and response_table.token_count.response_tokens or nil, - prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens or nil, - total_tokens = response_table.token_count and response_table.token_count.total_tokens or nil, + completion_tokens = response_table.token_count and response_table.token_count.response_tokens, + prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens, + total_tokens = response_table.token_count and response_table.token_count.total_tokens, } prompt.usage = stats @@ -400,12 +393,7 @@ function _M.post_request(conf) end function _M.pre_request(conf, body) - -- check for user trying to bring own model - if body and body.model then - return false, "cannot use own model for this instance" - end - - return true, nil + return true end function _M.subrequest(body, conf, http_opts, return_res_table) @@ -467,7 +455,7 @@ end function _M.configure_request(conf) local parsed_url - if conf.model.options.upstream_url then + if conf.model.options and conf.model.options.upstream_url then parsed_url = socket_url.parse(conf.model.options.upstream_url) else parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) @@ -476,10 +464,6 @@ function _M.configure_request(conf) or ai_shared.operation_map[DRIVER_NAME][conf.route_type] and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path or "/" - - if not parsed_url.path then - return false, fmt("operation %s is not supported for cohere 
provider", conf.route_type) - end end -- if the path is read from a URL capture, ensure that it is valid diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index b08f29bc325..1c592e5ef60 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -18,7 +18,7 @@ end local transformers_to = { ["llm/v1/chat"] = function(request_table, model_info, route_type) - request_table.model = request_table.model or model_info.name + request_table.model = model_info.name or request_table.model request_table.stream = request_table.stream or false -- explicitly set this request_table.top_k = nil -- explicitly remove unsupported default @@ -26,7 +26,7 @@ local transformers_to = { end, ["llm/v1/completions"] = function(request_table, model_info, route_type) - request_table.model = model_info.name + request_table.model = model_info.name or request_table.model request_table.stream = request_table.stream or false -- explicitly set this request_table.top_k = nil -- explicitly remove unsupported default diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 0b9cdcf3ab3..2e36a522142 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -131,10 +131,10 @@ _M.clear_response_headers = { -- @return {string} error if any is thrown - request should definitely be terminated if this is not nil function _M.merge_config_defaults(request, options, request_format) if options then - request.temperature = request.temperature or options.temperature - request.max_tokens = request.max_tokens or options.max_tokens - request.top_p = request.top_p or options.top_p - request.top_k = request.top_k or options.top_k + request.temperature = options.temperature or request.temperature + request.max_tokens = options.max_tokens or request.max_tokens + request.top_p = options.top_p or request.top_p + request.top_k = options.top_k or request.top_k end return request, nil @@ -197,28 +197,44 @@ end function _M.frame_to_events(frame) local events = {} - -- todo check if it's raw json and + -- Cohere / Other flat-JSON format parser -- just return the split up data frame - if string.sub(str_ltrim(frame), 1, 1) == "{" then + if (not kong or not kong.ctx.plugin.truncated_frame) and string.sub(str_ltrim(frame), 1, 1) == "{" then for event in frame:gmatch("[^\r\n]+") do events[#events + 1] = { data = event, } end else + -- standard SSE parser local event_lines = split(frame, "\n") local struct = { event = nil, id = nil, data = nil } - for _, dat in ipairs(event_lines) do + for i, dat in ipairs(event_lines) do if #dat < 1 then events[#events + 1] = struct struct = { event = nil, id = nil, data = nil } end + -- test for truncated chunk on the last line (no trailing \r\n\r\n) + if #dat > 0 and #event_lines == i then + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") + kong.ctx.plugin.truncated_frame = dat + break -- stop parsing immediately, server has done something wrong + end + + -- test for abnormal start-of-frame (truncation tail) + if kong and kong.ctx.plugin.truncated_frame then + -- this is the tail of a previous incomplete chunk + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") + dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) + kong.ctx.plugin.truncated_frame = nil + end + local s1, _ = str_find(dat, ":") -- find where the cut point is if s1 and s1 ~= 1 then - local field = str_sub(dat, 1, s1-1) -- returns "data " from data: hello world + local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world local value = 
str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world -- for now not checking if the value is already been set @@ -249,7 +265,7 @@ function _M.to_ollama(request_table, model) -- common parameters input.stream = request_table.stream or false -- for future capability - input.model = model.name + input.model = model.name or request_table.name if model.options then input.options = {} @@ -603,8 +619,10 @@ end -- Function to count the number of words in a string local function count_words(str) local count = 0 - for word in str:gmatch("%S+") do - count = count + 1 + if type(str) == "string" then + for word in str:gmatch("%S+") do + count = count + 1 + end end return count end diff --git a/kong/llm/init.lua b/kong/llm/init.lua index af3833ff44f..aaf3af08a79 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -1,393 +1,225 @@ --- imports -local typedefs = require("kong.db.schema.typedefs") -local fmt = string.format -local cjson = require("cjson.safe") -local re_match = ngx.re.match local ai_shared = require("kong.llm.drivers.shared") --- - -local _M = {} - -local auth_schema = { - type = "record", - required = false, - fields = { - { header_name = { - type = "string", - description = "If AI model requires authentication via Authorization or API key header, specify its name here.", - required = false, - referenceable = true }}, - { header_value = { - type = "string", - description = "Specify the full auth header value for 'header_name', for example 'Bearer key' or just 'key'.", - required = false, - encrypted = true, -- [[ ee declaration ]] - referenceable = true }}, - { param_name = { - type = "string", - description = "If AI model requires authentication via query parameter, specify its name here.", - required = false, - referenceable = true }}, - { param_value = { - type = "string", - description = "Specify the full parameter value for 'param_name'.", - required = false, - encrypted = true, -- [[ ee declaration ]] - referenceable = true }}, - { param_location = { - type = "string", - description = "Specify whether the 'param_name' and 'param_value' options go in a query string, or the POST form/JSON body.", - required = false, - one_of = { "query", "body" } }}, - } -} +local re_match = ngx.re.match +local cjson = require("cjson.safe") +local fmt = string.format +local EMPTY = {} -local model_options_schema = { - description = "Key/value settings for the model", - type = "record", - required = false, - fields = { - { max_tokens = { - type = "integer", - description = "Defines the max_tokens, if using chat or completion models.", - required = false, - default = 256 }}, - { temperature = { - type = "number", - description = "Defines the matching temperature, if using chat or completion models.", - required = false, - between = { 0.0, 5.0 }}}, - { top_p = { - type = "number", - description = "Defines the top-p probability mass, if supported.", - required = false, - between = { 0, 1 }}}, - { top_k = { - type = "integer", - description = "Defines the top-k most likely tokens, if supported.", - required = false, - between = { 0, 500 }}}, - { anthropic_version = { - type = "string", - description = "Defines the schema/API version, if using Anthropic provider.", - required = false }}, - { azure_instance = { - type = "string", - description = "Instance name for Azure OpenAI hosted models.", - required = false }}, - { azure_api_version = { - type = "string", - description = "'api-version' for Azure OpenAI instances.", - required = false, - default = "2023-05-15" }}, - { 
azure_deployment_id = { - type = "string", - description = "Deployment ID for Azure OpenAI instances.", - required = false }}, - { llama2_format = { - type = "string", - description = "If using llama2 provider, select the upstream message format.", - required = false, - one_of = { "raw", "openai", "ollama" }}}, - { mistral_format = { - type = "string", - description = "If using mistral provider, select the upstream message format.", - required = false, - one_of = { "openai", "ollama" }}}, - { upstream_url = typedefs.url { - description = "Manually specify or override the full URL to the AI operation endpoints, " - .. "when calling (self-)hosted models, or for running via a private endpoint.", - required = false }}, - { upstream_path = { - description = "Manually specify or override the AI operation path, " - .. "used when e.g. using the 'preserve' route_type.", - type = "string", - required = false }}, - } -} -local model_schema = { - type = "record", - required = true, - fields = { - { provider = { - type = "string", description = "AI provider request format - Kong translates " - .. "requests to and from the specified backend compatible formats.", - required = true, - one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2" }}}, - { name = { - type = "string", - description = "Model name to execute.", - required = false }}, - { options = model_options_schema }, - } +-- The module table +local _M = { + config_schema = require "kong.llm.schemas", } -local logging_schema = { - type = "record", - required = true, - fields = { - { log_statistics = { - type = "boolean", - description = "If enabled and supported by the driver, " - .. "will add model usage and token metrics into the Kong log plugin(s) output.", - required = true, - default = false }}, - { log_payloads = { - type = "boolean", - description = "If enabled, will log the request and response body into the Kong log plugin(s) output.", - required = true, default = false }}, - } -} -local UNSUPPORTED_LOG_STATISTICS = { - ["llm/v1/completions"] = { ["anthropic"] = true }, -} -_M.config_schema = { - type = "record", - fields = { - { route_type = { - type = "string", - description = "The model's operation implementation, for this provider. " .. 
- "Set to `preserve` to pass through without transformation.", - required = true, - one_of = { "llm/v1/chat", "llm/v1/completions", "preserve" } }}, - { auth = auth_schema }, - { model = model_schema }, - { logging = logging_schema }, - }, - entity_checks = { - -- these three checks run in a chain, to ensure that all auth params for each respective "set" are specified - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "openai", "azure", "anthropic", "cohere" } }, - then_at_least_one_of = { "auth.header_name", "auth.param_name" }, - then_err = "must set one of %s, and its respective options, when provider is not self-hosted" }}, - - { mutually_required = { "auth.header_name", "auth.header_value" }, }, - { mutually_required = { "auth.param_name", "auth.param_value", "auth.param_location" }, }, - - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "llama2" } }, - then_at_least_one_of = { "model.options.llama2_format" }, - then_err = "must set %s for llama2 provider" }}, - - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "mistral" } }, - then_at_least_one_of = { "model.options.mistral_format" }, - then_err = "must set %s for mistral provider" }}, - - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "anthropic" } }, - then_at_least_one_of = { "model.options.anthropic_version" }, - then_err = "must set %s for anthropic provider" }}, - - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "azure" } }, - then_at_least_one_of = { "model.options.azure_instance" }, - then_err = "must set %s for azure provider" }}, - - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "azure" } }, - then_at_least_one_of = { "model.options.azure_api_version" }, - then_err = "must set %s for azure provider" }}, - - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "azure" } }, - then_at_least_one_of = { "model.options.azure_deployment_id" }, - then_err = "must set %s for azure provider" }}, - - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "mistral", "llama2" } }, - then_at_least_one_of = { "model.options.upstream_url" }, - then_err = "must set %s for self-hosted providers/models" }}, - - { - custom_entity_check = { - field_sources = { "route_type", "model", "logging" }, - fn = function(entity) - if entity.logging.log_statistics and UNSUPPORTED_LOG_STATISTICS[entity.route_type] - and UNSUPPORTED_LOG_STATISTICS[entity.route_type][entity.model.provider] then - return nil, fmt("%s does not support statistics when route_type is %s", - entity.model.provider, entity.route_type) - - else - return true - end - end, - } +do + -- formats_compatible is a map of formats that are compatible with each other. 
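-- Usage sketch for the format helpers defined in this block (the request
-- tables below are illustrative, not taken from the test suite):
local llm = require("kong.llm")

-- a non-empty `messages` array identifies as llm/v1/chat, so this returns true
assert(llm.is_compatible({ messages = { { role = "user", content = "hi" } } }, "llm/v1/chat"))

-- a string `prompt` identifies as llm/v1/completions, so this returns
-- false, "[llm/v1/completions] message format is not compatible with [llm/v1/chat] route type"
local ok, err = llm.is_compatible({ prompt = "hi" }, "llm/v1/chat")

-- "preserve" is always compatible and skips body inspection entirely
assert(llm.is_compatible({ prompt = "hi" }, "preserve"))
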
+ local formats_compatible = { + ["llm/v1/chat"] = { + ["llm/v1/chat"] = true, }, - }, -} + ["llm/v1/completions"] = { + ["llm/v1/completions"] = true, + }, + } -local formats_compatible = { - ["llm/v1/chat"] = { - ["llm/v1/chat"] = true, - }, - ["llm/v1/completions"] = { - ["llm/v1/completions"] = true, - }, -} -local function identify_request(request) - -- primitive request format determination - local formats = {} - if request.messages - and type(request.messages) == "table" - and #request.messages > 0 - then - table.insert(formats, "llm/v1/chat") - end - - if request.prompt - and type(request.prompt) == "string" - then - table.insert(formats, "llm/v1/completions") - end + -- identify_request determines the format of the request. + -- It returns the format, or nil and an error message. + -- @tparam table request The request to identify + -- @treturn[1] string The format of the request + -- @treturn[2] nil + -- @treturn[2] string An error message if unidentified, or matching multiple formats + local function identify_request(request) + -- primitive request format determination + local formats = {} - if #formats > 1 then - return nil, "request matches multiple LLM request formats" - elseif not formats_compatible[formats[1]] then - return nil, "request format not recognised" - else - return formats[1] - end -end + if type(request.messages) == "table" and #request.messages > 0 then + table.insert(formats, "llm/v1/chat") + end -function _M.is_compatible(request, route_type) - if route_type == "preserve" then - return true - end + if type(request.prompt) == "string" then + table.insert(formats, "llm/v1/completions") + end - local format, err = identify_request(request) - if err then - return nil, err + if formats[2] then + return nil, "request matches multiple LLM request formats" + elseif not formats_compatible[formats[1] or false] then + return nil, "request format not recognised" + else + return formats[1] + end end - if formats_compatible[format][route_type] then - return true - end - return false, fmt("[%s] message format is not compatible with [%s] route type", format, route_type) -end -function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex_match) - local err, _ + --- Check if a request is compatible with a route type. + -- @tparam table request The request to check + -- @tparam string route_type The route type to check against, eg. 
"llm/v1/chat" + -- @treturn[1] boolean True if compatible + -- @treturn[2] boolean False if not compatible + -- @treturn[2] string Error message if not compatible + -- @treturn[3] nil + -- @treturn[3] string Error message if request format is not recognised + function _M.is_compatible(request, route_type) + if route_type == "preserve" then + return true + end - -- set up the request - local ai_request = { - messages = { - [1] = { - role = "system", - content = system_prompt, - }, - [2] = { - role = "user", - content = request, - } - }, - stream = false, - } + local format, err = identify_request(request) + if err then + return nil, err + end + + if formats_compatible[format][route_type] then + return true + end - -- convert it to the specified driver format - ai_request, _, err = self.driver.to_format(ai_request, self.conf.model, "llm/v1/chat") - if err then - return nil, err + return false, fmt("[%s] message format is not compatible with [%s] route type", format, route_type) end +end - -- run the shared logging/analytics/auth function - ai_shared.pre_request(self.conf, ai_request) - -- send it to the ai service - local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false) - if err then - return nil, "failed to introspect request with AI service: " .. err - end +do + ------------------------------------------------------------------------------ + -- LLM class implementation + ------------------------------------------------------------------------------ + local LLM = {} + LLM.__index = LLM - -- parse and convert the response - local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type) - if err then - return nil, "failed to convert AI response to Kong format: " .. err - end - -- run the shared logging/analytics function - ai_shared.post_request(self.conf, ai_response) - local ai_response, err = cjson.decode(ai_response) - if err then - return nil, "failed to convert AI response to JSON: " .. err - end + function LLM:ai_introspect_body(request, system_prompt, http_opts, response_regex_match) + local err, _ - local new_request_body = ai_response.choices - and #ai_response.choices > 0 - and ai_response.choices[1] - and ai_response.choices[1].message - and ai_response.choices[1].message.content - if not new_request_body then - return nil, "no 'choices' in upstream AI service response" - end + -- set up the request + local ai_request = { + messages = { + [1] = { + role = "system", + content = system_prompt, + }, + [2] = { + role = "user", + content = request, + } + }, + stream = false, + } - -- if specified, extract the first regex match from the AI response - -- this is useful for AI models that pad with assistant text, even when - -- we ask them NOT to. - if response_regex_match then - local matches, err = re_match(new_request_body, response_regex_match, "ijom") + -- convert it to the specified driver format + ai_request, _, err = self.driver.to_format(ai_request, self.conf.model, "llm/v1/chat") if err then - return nil, "failed regex matching ai response: " .. 
err + return nil, err end - if matches then - new_request_body = matches[0] -- this array DOES start at 0, for some reason + -- run the shared logging/analytics/auth function + ai_shared.pre_request(self.conf, ai_request) - else - return nil, "AI response did not match specified regular expression" + -- send it to the ai service + local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false) + if err then + return nil, "failed to introspect request with AI service: " .. err + end + -- parse and convert the response + local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type) + if err then + return nil, "failed to convert AI response to Kong format: " .. err end - end - return new_request_body -end + -- run the shared logging/analytics function + ai_shared.post_request(self.conf, ai_response) -function _M:parse_json_instructions(in_body) - local err - if type(in_body) == "string" then - in_body, err = cjson.decode(in_body) + local ai_response, err = cjson.decode(ai_response) if err then - return nil, nil, nil, err + return nil, "failed to convert AI response to JSON: " .. err end - end - if type(in_body) ~= "table" then - return nil, nil, nil, "input not table or string" + local new_request_body = ((ai_response.choices or EMPTY)[1].message or EMPTY).content + if not new_request_body then + return nil, "no 'choices' in upstream AI service response" + end + + -- if specified, extract the first regex match from the AI response + -- this is useful for AI models that pad with assistant text, even when + -- we ask them NOT to. + if response_regex_match then + local matches, err = re_match(new_request_body, response_regex_match, "ijom") + if err then + return nil, "failed regex matching ai response: " .. err + end + + if matches then + new_request_body = matches[0] -- this array DOES start at 0, for some reason + + else + return nil, "AI response did not match specified regular expression" + + end + end + + return new_request_body end - return - in_body.headers, - in_body.body or in_body, - in_body.status or 200 -end -function _M:new(conf, http_opts) - local o = {} - setmetatable(o, self) - self.__index = self - self.conf = conf or {} - self.http_opts = http_opts or {} + -- Parse the response instructions. + -- @tparam string|table in_body The response to parse, if a string, it will be parsed as JSON. + -- @treturn[1] table The headers, field `in_body.headers` + -- @treturn[1] string The body, field `in_body.body` (or if absent `in_body` itself as a table) + -- @treturn[1] number The status, field `in_body.status` (or 200 if absent) + -- @treturn[2] nil + -- @treturn[2] string An error message if parsing failed or input wasn't a table + function LLM:parse_json_instructions(in_body) + local err + if type(in_body) == "string" then + in_body, err = cjson.decode(in_body) + if err then + return nil, nil, nil, err + end + end + + if type(in_body) ~= "table" then + return nil, nil, nil, "input not table or string" + end + + return + in_body.headers, + in_body.body or in_body, + in_body.status or 200 + end - local driver = fmt("kong.llm.drivers.%s", conf - and conf.model - and conf.model.provider - or "NONE_SET") - self.driver = require(driver) - if not self.driver then - return nil, fmt("could not instantiate %s package", driver) + --- Instantiate a new LLM driver instance. 
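-- Call pattern used by the ai-request-transformer and ai-response-transformer
-- handlers further down in this diff (trimmed; the optional fourth argument to
-- ai_introspect_body is its response_regex_match parameter):
--
--   local ai_driver, err = llm.new_driver(conf.llm, http_opts)
--   if not ai_driver then
--     return internal_server_error(err)
--   end
--   local new_body, err = ai_driver:ai_introspect_body(kong.request.get_raw_body(),
--                                                      conf.prompt, http_opts)
--
-- The constructor pcall-requires "kong.llm.drivers.<provider>", so an unknown
-- provider surfaces as (nil, "could not instantiate ... package") instead of a
-- hard error from require().
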
+ -- @tparam table conf Configuration table + -- @tparam table http_opts HTTP options table + -- @treturn[1] table A new LLM driver instance + -- @treturn[2] nil + -- @treturn[2] string An error message if instantiation failed + function _M.new_driver(conf, http_opts) + local self = { + conf = conf or {}, + http_opts = http_opts or {}, + } + setmetatable(self, LLM) + + local provider = (self.conf.model or {}).provider or "NONE_SET" + local driver_module = "kong.llm.drivers." .. provider + local ok + ok, self.driver = pcall(require, driver_module) + if not ok then + local err = "could not instantiate " .. driver_module .. " package" + kong.log.err(err) + return nil, err + end + + return self end - return o end + return _M diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua new file mode 100644 index 00000000000..37b5aaf3476 --- /dev/null +++ b/kong/llm/schemas/init.lua @@ -0,0 +1,226 @@ +local typedefs = require("kong.db.schema.typedefs") +local fmt = string.format + + + +local auth_schema = { + type = "record", + required = false, + fields = { + { header_name = { + type = "string", + description = "If AI model requires authentication via Authorization or API key header, specify its name here.", + required = false, + referenceable = true }}, + { header_value = { + type = "string", + description = "Specify the full auth header value for 'header_name', for example 'Bearer key' or just 'key'.", + required = false, + encrypted = true, -- [[ ee declaration ]] + referenceable = true }}, + { param_name = { + type = "string", + description = "If AI model requires authentication via query parameter, specify its name here.", + required = false, + referenceable = true }}, + { param_value = { + type = "string", + description = "Specify the full parameter value for 'param_name'.", + required = false, + encrypted = true, -- [[ ee declaration ]] + referenceable = true }}, + { param_location = { + type = "string", + description = "Specify whether the 'param_name' and 'param_value' options go in a query string, or the POST form/JSON body.", + required = false, + one_of = { "query", "body" } }}, + } +} + + + +local model_options_schema = { + description = "Key/value settings for the model", + type = "record", + required = false, + fields = { + { max_tokens = { + type = "integer", + description = "Defines the max_tokens, if using chat or completion models.", + required = false, + default = 256 }}, + { temperature = { + type = "number", + description = "Defines the matching temperature, if using chat or completion models.", + required = false, + between = { 0.0, 5.0 }}}, + { top_p = { + type = "number", + description = "Defines the top-p probability mass, if supported.", + required = false, + between = { 0, 1 }}}, + { top_k = { + type = "integer", + description = "Defines the top-k most likely tokens, if supported.", + required = false, + between = { 0, 500 }}}, + { anthropic_version = { + type = "string", + description = "Defines the schema/API version, if using Anthropic provider.", + required = false }}, + { azure_instance = { + type = "string", + description = "Instance name for Azure OpenAI hosted models.", + required = false }}, + { azure_api_version = { + type = "string", + description = "'api-version' for Azure OpenAI instances.", + required = false, + default = "2023-05-15" }}, + { azure_deployment_id = { + type = "string", + description = "Deployment ID for Azure OpenAI instances.", + required = false }}, + { llama2_format = { + type = "string", + description = "If using llama2 
provider, select the upstream message format.", + required = false, + one_of = { "raw", "openai", "ollama" }}}, + { mistral_format = { + type = "string", + description = "If using mistral provider, select the upstream message format.", + required = false, + one_of = { "openai", "ollama" }}}, + { upstream_url = typedefs.url { + description = "Manually specify or override the full URL to the AI operation endpoints, " + .. "when calling (self-)hosted models, or for running via a private endpoint.", + required = false }}, + { upstream_path = { + description = "Manually specify or override the AI operation path, " + .. "used when e.g. using the 'preserve' route_type.", + type = "string", + required = false }}, + } +} + + + +local model_schema = { + type = "record", + required = true, + fields = { + { provider = { + type = "string", description = "AI provider request format - Kong translates " + .. "requests to and from the specified backend compatible formats.", + required = true, + one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2" }}}, + { name = { + type = "string", + description = "Model name to execute.", + required = false }}, + { options = model_options_schema }, + } +} + + + +local logging_schema = { + type = "record", + required = true, + fields = { + { log_statistics = { + type = "boolean", + description = "If enabled and supported by the driver, " + .. "will add model usage and token metrics into the Kong log plugin(s) output.", + required = true, + default = false }}, + { log_payloads = { + type = "boolean", + description = "If enabled, will log the request and response body into the Kong log plugin(s) output.", + required = true, default = false }}, + } +} + + + +local UNSUPPORTED_LOG_STATISTICS = { + ["llm/v1/completions"] = { ["anthropic"] = true }, +} + + + +return { + type = "record", + fields = { + { route_type = { + type = "string", + description = "The model's operation implementation, for this provider. " .. 
+ "Set to `preserve` to pass through without transformation.", + required = true, + one_of = { "llm/v1/chat", "llm/v1/completions", "preserve" } }}, + { auth = auth_schema }, + { model = model_schema }, + { logging = logging_schema }, + }, + entity_checks = { + -- these three checks run in a chain, to ensure that all auth params for each respective "set" are specified + { conditional_at_least_one_of = { if_field = "model.provider", + if_match = { one_of = { "openai", "azure", "anthropic", "cohere" } }, + then_at_least_one_of = { "auth.header_name", "auth.param_name" }, + then_err = "must set one of %s, and its respective options, when provider is not self-hosted" }}, + + { mutually_required = { "auth.header_name", "auth.header_value" }, }, + { mutually_required = { "auth.param_name", "auth.param_value", "auth.param_location" }, }, + + { conditional_at_least_one_of = { if_field = "model.provider", + if_match = { one_of = { "llama2" } }, + then_at_least_one_of = { "model.options.llama2_format" }, + then_err = "must set %s for llama2 provider" }}, + + { conditional_at_least_one_of = { if_field = "model.provider", + if_match = { one_of = { "mistral" } }, + then_at_least_one_of = { "model.options.mistral_format" }, + then_err = "must set %s for mistral provider" }}, + + { conditional_at_least_one_of = { if_field = "model.provider", + if_match = { one_of = { "anthropic" } }, + then_at_least_one_of = { "model.options.anthropic_version" }, + then_err = "must set %s for anthropic provider" }}, + + { conditional_at_least_one_of = { if_field = "model.provider", + if_match = { one_of = { "azure" } }, + then_at_least_one_of = { "model.options.azure_instance" }, + then_err = "must set %s for azure provider" }}, + + { conditional_at_least_one_of = { if_field = "model.provider", + if_match = { one_of = { "azure" } }, + then_at_least_one_of = { "model.options.azure_api_version" }, + then_err = "must set %s for azure provider" }}, + + { conditional_at_least_one_of = { if_field = "model.provider", + if_match = { one_of = { "azure" } }, + then_at_least_one_of = { "model.options.azure_deployment_id" }, + then_err = "must set %s for azure provider" }}, + + { conditional_at_least_one_of = { if_field = "model.provider", + if_match = { one_of = { "mistral", "llama2" } }, + then_at_least_one_of = { "model.options.upstream_url" }, + then_err = "must set %s for self-hosted providers/models" }}, + + { + custom_entity_check = { + field_sources = { "route_type", "model", "logging" }, + fn = function(entity) + if entity.logging.log_statistics and UNSUPPORTED_LOG_STATISTICS[entity.route_type] + and UNSUPPORTED_LOG_STATISTICS[entity.route_type][entity.model.provider] then + return nil, fmt("%s does not support statistics when route_type is %s", + entity.model.provider, entity.route_type) + + else + return true + end + end, + } + }, + }, +} diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 739c33f0667..35e13fbe8d9 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -1,6 +1,3 @@ -local _M = {} - --- imports local ai_shared = require("kong.llm.drivers.shared") local llm = require("kong.llm") local cjson = require("cjson.safe") @@ -8,52 +5,41 @@ local kong_utils = require("kong.tools.gzip") local kong_meta = require("kong.meta") local buffer = require "string.buffer" local strip = require("kong.tools.utils").strip --- -_M.PRIORITY = 770 -_M.VERSION = kong_meta.version +local EMPTY = {} --- reuse this table for error message response -local ERROR_MSG = 
{ error = { message = "" } } +local _M = { + PRIORITY = 770, + VERSION = kong_meta.version +} -local function bad_request(msg) - kong.log.warn(msg) - ERROR_MSG.error.message = msg - return kong.response.exit(400, ERROR_MSG) +--- Return a 400 response with a JSON body. This function is used to +-- return errors to the client while also logging the error. +local function bad_request(msg) + kong.log.info(msg) + return kong.response.exit(400, { error = { message = msg } }) end -local function internal_server_error(msg) - kong.log.err(msg) - ERROR_MSG.error.message = msg - - return kong.response.exit(500, ERROR_MSG) -end +-- get the token text from an event frame local function get_token_text(event_t) - -- chat - return - event_t and - event_t.choices and - #event_t.choices > 0 and - event_t.choices[1].delta and - event_t.choices[1].delta.content - - or - - -- completions - event_t and - event_t.choices and - #event_t.choices > 0 and - event_t.choices[1].text - - or "" + -- get: event_t.choices[1] + local first_choice = ((event_t or EMPTY).choices or EMPTY)[1] or EMPTY + -- return: + -- - event_t.choices[1].delta.content + -- - event_t.choices[1].text + -- - "" + local token_text = (first_choice.delta or EMPTY).content or first_choice.text or "" + return (type(token_text) == "string" and token_text) or "" end + + local function handle_streaming_frame(conf) -- make a re-usable framebuffer local framebuffer = buffer.new() @@ -61,11 +47,11 @@ local function handle_streaming_frame(conf) local ai_driver = require("kong.llm.drivers." .. conf.model.provider) + local kong_ctx_plugin = kong.ctx.plugin -- create a buffer to store each response token/frame, on first pass - if conf.logging - and conf.logging.log_payloads - and (not kong.ctx.plugin.ai_stream_log_buffer) then - kong.ctx.plugin.ai_stream_log_buffer = buffer.new() + if (conf.logging or EMPTY).log_payloads and + (not kong_ctx_plugin.ai_stream_log_buffer) then + kong_ctx_plugin.ai_stream_log_buffer = buffer.new() end -- now handle each chunk/frame @@ -101,7 +87,7 @@ local function handle_streaming_frame(conf) token_t = get_token_text(event_t) end - kong.ctx.plugin.ai_stream_log_buffer:put(token_t) + kong_ctx_plugin.ai_stream_log_buffer:put(token_t) end end @@ -122,8 +108,8 @@ local function handle_streaming_frame(conf) -- but this is all we can do until OpenAI fixes this... 
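-- Worked example of the character-count estimate applied just below: for a
-- streamed delta of "Hello, world!" (13 characters after strip()),
-- math.ceil(13 / 4) adds 4 to ai_stream_completion_tokens; a one-character
-- delta still counts as a whole token, since math.ceil(1 / 4) == 1.
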
-- -- essentially, every 4 characters is a token, with minimum of 1*4 per event - kong.ctx.plugin.ai_stream_completion_tokens = - (kong.ctx.plugin.ai_stream_completion_tokens or 0) + math.ceil(#strip(token_t) / 4) + kong_ctx_plugin.ai_stream_completion_tokens = + (kong_ctx_plugin.ai_stream_completion_tokens or 0) + math.ceil(#strip(token_t) / 4) end end end @@ -135,14 +121,14 @@ local function handle_streaming_frame(conf) end if conf.logging and conf.logging.log_statistics and metadata then - kong.ctx.plugin.ai_stream_completion_tokens = - (kong.ctx.plugin.ai_stream_completion_tokens or 0) + + kong_ctx_plugin.ai_stream_completion_tokens = + (kong_ctx_plugin.ai_stream_completion_tokens or 0) + (metadata.completion_tokens or 0) - or kong.ctx.plugin.ai_stream_completion_tokens - kong.ctx.plugin.ai_stream_prompt_tokens = - (kong.ctx.plugin.ai_stream_prompt_tokens or 0) + + or kong_ctx_plugin.ai_stream_completion_tokens + kong_ctx_plugin.ai_stream_prompt_tokens = + (kong_ctx_plugin.ai_stream_prompt_tokens or 0) + (metadata.prompt_tokens or 0) - or kong.ctx.plugin.ai_stream_prompt_tokens + or kong_ctx_plugin.ai_stream_prompt_tokens end end end @@ -156,23 +142,26 @@ local function handle_streaming_frame(conf) if finished then local fake_response_t = { - response = kong.ctx.plugin.ai_stream_log_buffer and kong.ctx.plugin.ai_stream_log_buffer:get(), + response = kong_ctx_plugin.ai_stream_log_buffer and kong_ctx_plugin.ai_stream_log_buffer:get(), usage = { - prompt_tokens = kong.ctx.plugin.ai_stream_prompt_tokens or 0, - completion_tokens = kong.ctx.plugin.ai_stream_completion_tokens or 0, - total_tokens = (kong.ctx.plugin.ai_stream_prompt_tokens or 0) - + (kong.ctx.plugin.ai_stream_completion_tokens or 0), + prompt_tokens = kong_ctx_plugin.ai_stream_prompt_tokens or 0, + completion_tokens = kong_ctx_plugin.ai_stream_completion_tokens or 0, + total_tokens = (kong_ctx_plugin.ai_stream_prompt_tokens or 0) + + (kong_ctx_plugin.ai_stream_completion_tokens or 0), } } ngx.arg[1] = nil ai_shared.post_request(conf, fake_response_t) - kong.ctx.plugin.ai_stream_log_buffer = nil + kong_ctx_plugin.ai_stream_log_buffer = nil end end function _M:header_filter(conf) - if kong.ctx.shared.skip_response_transformer then + local kong_ctx_plugin = kong.ctx.plugin + local kong_ctx_shared = kong.ctx.shared + + if kong_ctx_shared.skip_response_transformer then return end @@ -187,7 +176,7 @@ function _M:header_filter(conf) end -- we use openai's streaming mode (SSE) - if kong.ctx.shared.ai_proxy_streaming_mode then + if kong_ctx_shared.ai_proxy_streaming_mode then -- we are going to send plaintext event-stream frames for ALL models kong.response.set_header("Content-Type", "text/event-stream") return @@ -204,28 +193,26 @@ function _M:header_filter(conf) -- if this is a 'streaming' request, we can't know the final -- result of the response body, so we just proceed to body_filter -- to translate each SSE event frame - if not kong.ctx.shared.ai_proxy_streaming_mode then + if not kong_ctx_shared.ai_proxy_streaming_mode then local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" if is_gzip then response_body = kong_utils.inflate_gzip(response_body) end if route_type == "preserve" then - kong.ctx.plugin.parsed_response = response_body + kong_ctx_plugin.parsed_response = response_body else local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) if err then - kong.ctx.plugin.ai_parser_error = true - + kong_ctx_plugin.ai_parser_error = true + ngx.status = 500 - 
ERROR_MSG.error.message = err - - kong.ctx.plugin.parsed_response = cjson.encode(ERROR_MSG) - + kong_ctx_plugin.parsed_response = cjson.encode({ error = { message = err } }) + elseif new_response_string then -- preserve the same response content type; assume the from_format function -- has returned the body in the appropriate response output format - kong.ctx.plugin.parsed_response = new_response_string + kong_ctx_plugin.parsed_response = new_response_string end end end @@ -235,17 +222,20 @@ end function _M:body_filter(conf) + local kong_ctx_plugin = kong.ctx.plugin + local kong_ctx_shared = kong.ctx.shared + -- if body_filter is called twice, then return - if kong.ctx.plugin.body_called and not kong.ctx.shared.ai_proxy_streaming_mode then + if kong_ctx_plugin.body_called and not kong_ctx_shared.ai_proxy_streaming_mode then return end local route_type = conf.route_type - if kong.ctx.shared.skip_response_transformer and (route_type ~= "preserve") then + if kong_ctx_shared.skip_response_transformer and (route_type ~= "preserve") then local response_body - if kong.ctx.shared.parsed_response then - response_body = kong.ctx.shared.parsed_response + if kong_ctx_shared.parsed_response then + response_body = kong_ctx_shared.parsed_response elseif kong.response.get_status() == 200 then response_body = kong.service.response.get_raw_body() if not response_body then @@ -261,7 +251,7 @@ function _M:body_filter(conf) local ai_driver = require("kong.llm.drivers." .. conf.model.provider) local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) - + if err then kong.log.warn("issue when transforming the response body for analytics in the body filter phase, ", err) elseif new_response_string then @@ -269,20 +259,20 @@ function _M:body_filter(conf) end end - if not kong.ctx.shared.skip_response_transformer then - if (kong.response.get_status() ~= 200) and (not kong.ctx.plugin.ai_parser_error) then + if not kong_ctx_shared.skip_response_transformer then + if (kong.response.get_status() ~= 200) and (not kong_ctx_plugin.ai_parser_error) then return end if route_type ~= "preserve" then - if kong.ctx.shared.ai_proxy_streaming_mode then + if kong_ctx_shared.ai_proxy_streaming_mode then handle_streaming_frame(conf) else -- all errors MUST be checked and returned in header_filter -- we should receive a replacement response body from the same thread - local original_request = kong.ctx.plugin.parsed_response + local original_request = kong_ctx_plugin.parsed_response local deflated_request = original_request - + if deflated_request then local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" if is_gzip then @@ -298,26 +288,29 @@ function _M:body_filter(conf) kong.log.warn("analytics phase failed for request, ", err) end end - end + end end - kong.ctx.plugin.body_called = true + kong_ctx_plugin.body_called = true end function _M:access(conf) + local kong_ctx_plugin = kong.ctx.plugin + local kong_ctx_shared = kong.ctx.shared + -- store the route_type in ctx for use in response parsing local route_type = conf.route_type - kong.ctx.plugin.operation = route_type + kong_ctx_plugin.operation = route_type local request_table local multipart = false -- we may have received a replacement / decorated request body from another AI plugin - if kong.ctx.shared.replacement_request then + if kong_ctx_shared.replacement_request then kong.log.debug("replacement request body received from another AI plugin") - request_table = kong.ctx.shared.replacement_request + request_table = 
kong_ctx_shared.replacement_request else -- first, calculate the coordinates of the request @@ -342,24 +335,32 @@ function _M:access(conf) -- copy from the user request if present if (not multipart) and (not conf_m.model.name) and (request_table.model) then - conf_m.model.name = request_table.model + if type(request_table.model) == "string" then + conf_m.model.name = request_table.model + end elseif multipart then conf_m.model.name = "NOT_SPECIFIED" end + -- check that the user isn't trying to override the plugin conf model in the request body + if request_table and request_table.model and type(request_table.model) == "string" and request_table.model ~= "" then + if request_table.model ~= conf_m.model.name then + return bad_request("cannot use own model - must be: " .. conf_m.model.name) + end + end + -- model is stashed in the copied plugin conf, for consistency in transformation functions if not conf_m.model.name then return bad_request("model parameter not found in request, nor in gateway configuration") end - -- stash for analytics later - kong.ctx.plugin.llm_model_requested = conf_m.model.name + kong_ctx_plugin.llm_model_requested = conf_m.model.name -- check the incoming format is the same as the configured LLM format if not multipart then local compatible, err = llm.is_compatible(request_table, route_type) if not compatible then - kong.ctx.shared.skip_response_transformer = true + kong_ctx_shared.skip_response_transformer = true return bad_request(err) end end @@ -367,7 +368,7 @@ function _M:access(conf) -- check the incoming format is the same as the configured LLM format local compatible, err = llm.is_compatible(request_table, route_type) if not compatible then - kong.ctx.shared.skip_response_transformer = true + kong_ctx_shared.skip_response_transformer = true return bad_request(err) end @@ -384,17 +385,18 @@ function _M:access(conf) end -- store token cost estimate, on first pass - if not kong.ctx.plugin.ai_stream_prompt_tokens then + if not kong_ctx_plugin.ai_stream_prompt_tokens then local prompt_tokens, err = ai_shared.calculate_cost(request_table or {}, {}, 1.8) if err then - return internal_server_error("unable to estimate request token cost: " .. 
err) + kong.log.err("unable to estimate request token cost: ", err) + return kong.response.exit(500) end - kong.ctx.plugin.ai_stream_prompt_tokens = prompt_tokens + kong_ctx_plugin.ai_stream_prompt_tokens = prompt_tokens end -- specific actions need to skip later for this to work - kong.ctx.shared.ai_proxy_streaming_mode = true + kong_ctx_shared.ai_proxy_streaming_mode = true else kong.service.request.enable_buffering() @@ -414,7 +416,7 @@ function _M:access(conf) -- transform the body to Kong-format for this provider/model parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type) if err then - kong.ctx.shared.skip_response_transformer = true + kong_ctx_shared.skip_response_transformer = true return bad_request(err) end end @@ -432,8 +434,9 @@ function _M:access(conf) -- now re-configure the request for this operation type local ok, err = ai_driver.configure_request(conf_m) if not ok then - kong.ctx.shared.skip_response_transformer = true - return internal_server_error(err) + kong_ctx_shared.skip_response_transformer = true + kong.log.err("failed to configure request for AI service: ", err) + return kong.response.exit(500) end -- lights out, and away we go diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua index 9517be36632..0eb5cd89d8f 100644 --- a/kong/plugins/ai-request-transformer/handler.lua +++ b/kong/plugins/ai-request-transformer/handler.lua @@ -31,7 +31,7 @@ local function create_http_opts(conf) http_opts.proxy_opts = http_opts.proxy_opts or {} http_opts.proxy_opts.https_proxy = fmt("http://%s:%d", conf.https_proxy_host, conf.https_proxy_port) end - + http_opts.http_timeout = conf.http_timeout http_opts.https_verify = conf.https_verify @@ -46,15 +46,15 @@ function _M:access(conf) local http_opts = create_http_opts(conf) conf.llm.__plugin_id = conf.__plugin_id conf.llm.__key__ = conf.__key__ - local ai_driver, err = llm:new(conf.llm, http_opts) - + local ai_driver, err = llm.new_driver(conf.llm, http_opts) + if not ai_driver then return internal_server_error(err) end -- if asked, introspect the request before proxying kong.log.debug("introspecting request with LLM") - local new_request_body, err = llm:ai_introspect_body( + local new_request_body, err = ai_driver:ai_introspect_body( kong.request.get_raw_body(), conf.prompt, http_opts, @@ -64,7 +64,7 @@ function _M:access(conf) if err then return bad_request(err) end - + -- set the body for later plugins kong.service.request.set_raw_body(new_request_body) diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua index 7014d893852..94a82a5ff2d 100644 --- a/kong/plugins/ai-response-transformer/handler.lua +++ b/kong/plugins/ai-response-transformer/handler.lua @@ -21,6 +21,8 @@ local function internal_server_error(msg) return kong.response.exit(500, { error = { message = msg } }) end + + local function subrequest(httpc, request_body, http_opts) httpc:set_timeouts(http_opts.http_timeout or 60000) @@ -72,6 +74,8 @@ local function subrequest(httpc, request_body, http_opts) return res end + + local function create_http_opts(conf) local http_opts = {} @@ -84,13 +88,15 @@ local function create_http_opts(conf) http_opts.proxy_opts = http_opts.proxy_opts or {} http_opts.proxy_opts.https_proxy = fmt("http://%s:%d", conf.https_proxy_host, conf.https_proxy_port) end - + http_opts.http_timeout = conf.http_timeout http_opts.https_verify = conf.https_verify return http_opts end + + function 
_M:access(conf) kong.service.request.enable_buffering() kong.ctx.shared.skip_response_transformer = true @@ -99,8 +105,8 @@ function _M:access(conf) local http_opts = create_http_opts(conf) conf.llm.__plugin_id = conf.__plugin_id conf.llm.__key__ = conf.__key__ - local ai_driver, err = llm:new(conf.llm, http_opts) - + local ai_driver, err = llm.new_driver(conf.llm, http_opts) + if not ai_driver then return internal_server_error(err) end @@ -123,7 +129,7 @@ function _M:access(conf) -- if asked, introspect the request before proxying kong.log.debug("introspecting response with LLM") - local new_response_body, err = llm:ai_introspect_body( + local new_response_body, err = ai_driver:ai_introspect_body( res_body, conf.prompt, http_opts, @@ -142,7 +148,7 @@ function _M:access(conf) local headers, body, status if conf.parse_llm_response_json_instructions then - headers, body, status, err = llm:parse_json_instructions(new_response_body) + headers, body, status, err = ai_driver:parse_json_instructions(new_response_body) if err then return internal_server_error("failed to parse JSON response instructions from AI backend: " .. err) end diff --git a/scripts/explain_manifest/main.py b/scripts/explain_manifest/main.py index 1033057d350..44f9dcc00fc 100755 --- a/scripts/explain_manifest/main.py +++ b/scripts/explain_manifest/main.py @@ -84,9 +84,12 @@ def gather_files(path: str, image: str): code = os.system( "ar p %s data.tar.gz | tar -C %s -xz" % (path, t.name)) elif ext == ".rpm": - # GNU cpio and rpm2cpio is needed + # rpm2cpio is needed + # rpm2archive ships with rpm2cpio on debians code = os.system( - "rpm2cpio %s | cpio --no-preserve-owner --no-absolute-filenames -idm -D %s" % (path, t.name)) + """ + rpm2archive %s && tar -C %s -xf %s.tgz + """ % (path, t.name, path)) elif ext == ".gz": code = os.system("tar -C %s -xf %s" % (t.name, path)) diff --git a/scripts/upgrade-tests/test-upgrade-path.sh b/scripts/upgrade-tests/test-upgrade-path.sh index 8144fd9513f..878c4c2f907 100755 --- a/scripts/upgrade-tests/test-upgrade-path.sh +++ b/scripts/upgrade-tests/test-upgrade-path.sh @@ -166,7 +166,7 @@ function run_tests() { echo ">> Setting up tests" docker exec -w /upgrade-test $OLD_CONTAINER $BUSTED_ENV /kong/bin/busted -t setup $TEST echo ">> Running migrations" - kong migrations up + kong migrations up --force echo ">> Testing old_after_up,all_phases" docker exec -w /upgrade-test $OLD_CONTAINER $BUSTED_ENV /kong/bin/busted -t old_after_up,all_phases $TEST echo ">> Testing new_after_up,all_phases" diff --git a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua index 4cfa96efea6..7f1322b14d1 100644 --- a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua +++ b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua @@ -586,6 +586,39 @@ describe("CP/DP config compat transformations #" .. 
strategy, function() admin.plugins:remove({ id = ai_response_transformer.id }) end) end) + + describe("www-authenticate header in plugins (realm config)", function() + it("[basic-auth] removes realm for versions below 3.6", function() + local basic_auth = admin.plugins:insert { + name = "basic-auth", + } + + local expected_basic_auth_prior_36 = utils.cycle_aware_deep_copy(basic_auth) + expected_basic_auth_prior_36.config.realm = nil + + do_assert(utils.uuid(), "3.5.0", expected_basic_auth_prior_36) + + -- cleanup + admin.plugins:remove({ id = basic_auth.id }) + end) + + it("[key-auth] removes realm for versions below 3.7", function() + local key_auth = admin.plugins:insert { + name = "key-auth", + config = { + realm = "test" + } + } + + local expected_key_auth_prior_37 = utils.cycle_aware_deep_copy(key_auth) + expected_key_auth_prior_37.config.realm = nil + + do_assert(utils.uuid(), "3.6.0", expected_key_auth_prior_37) + + -- cleanup + admin.plugins:remove({ id = key_auth.id }) + end) + end) end) end) diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 9ff754a1407..c1dfadfb4ac 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -629,7 +629,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS, { max_tokens = 1024, - top_p = 1.0, + top_p = 0.5, }, "llm/v1/chat" ) @@ -638,9 +638,9 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.is_nil(err) assert.same({ - max_tokens = 256, + max_tokens = 1024, temperature = 0.1, - top_p = 0.2, + top_p = 0.5, some_extra_param = "string_val", another_extra_param = 0.5, }, formatted) diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index e9fb74c3114..c218353bdb2 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -841,6 +841,52 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, json.choices[1].message) end) + it("good request, parses model of cjson.null", function() + local body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json") + body = cjson.decode(body) + body.model = cjson.null + body = cjson.encode(body) + + local r = client:get("/openai/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = body, + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "gpt-3.5-turbo-0613") + assert.equals(json.object, "chat.completion") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("tries to override configured model", function() + local r = client:get("/openai/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"), + }) + + local body = assert.res_status(400 , r) + local json = cjson.decode(body) + + assert.same(json, {error = { message = "cannot use own model - must be: 
gpt-3.5-turbo" } }) + end) + it("bad upstream response", function() local r = client:get("/openai/llm/v1/chat/bad_upstream_response", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua index 7dc325d8f8f..0d78e57b778 100644 --- a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua @@ -90,6 +90,59 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } } + location = "/openai/llm/v1/chat/partial" { + content_by_lua_block { + local _EVENT_CHUNKS = { + [1] = 'data: { "choices": [ { "delta": { "content": "", "role": "assistant" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [2] = 'data: { "choices": [ { "delta": { "content": "The " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}\n\ndata: { "choices": [ { "delta": { "content": "answer " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [3] = 'data: { "choices": [ { "delta": { "content": "to 1 + " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Ts', + [4] = 'w1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [5] = 'data: { "choices": [ { "delta": { "content": "1 is " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}\n\ndata: { "choices": [ { "delta": { "content": "2." 
}, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [6] = 'data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [7] = 'data: [DONE]', + } + + local fmt = string.format + local pl_file = require "pl.file" + local json = require("cjson.safe") + + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local token = ngx.req.get_headers()["authorization"] + local token_query = ngx.req.get_uri_args()["apikey"] + + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (body.messages == ngx.null) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) + else + -- GOOD RESPONSE + + ngx.status = 200 + ngx.header["Content-Type"] = "text/event-stream" + + for i, EVENT in ipairs(_EVENT_CHUNKS) do + -- pretend to truncate chunks + if _EVENT_CHUNKS[i+1] and _EVENT_CHUNKS[i+1]:sub(1, 5) ~= "data:" then + ngx.print(EVENT) + else + ngx.print(fmt("%s\n\n", EVENT)) + end + end + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/unauthorized.json")) + end + } + } + location = "/cohere/llm/v1/chat/good" { content_by_lua_block { local _EVENT_CHUNKS = { @@ -291,6 +344,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } -- + -- 200 chat openai - PARTIAL SPLIT CHUNKS + local openai_chat_partial = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/chat/partial" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = openai_chat_partial.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/openai/llm/v1/chat/partial" + }, + }, + }, + } + -- + -- 200 chat cohere local cohere_chat_good = assert(bp.routes:insert { service = empty_service, @@ -489,6 +571,69 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.equal(buf:tostring(), "The answer to 1 + 1 is 2.") end) + it("good stream request openai with partial split chunks", function() + local httpc = http.new() + + local ok, err, _ = httpc:connect({ + scheme = "http", + host = helpers.mock_upstream_host, + port = helpers.get_proxy_port(), + }) + if not ok then + assert.is_nil(err) + end + + -- Then send using `request`, supplying a path and `Host` header instead of a + -- full URI. 
+ local res, err = httpc:request({ + path = "/openai/llm/v1/chat/partial", + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json"), + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + }) + if not res then + assert.is_nil(err) + end + + local reader = res.body_reader + local buffer_size = 35536 + local events = {} + local buf = require("string.buffer").new() + + -- extract event + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + assert.is_falsy(err and err ~= "closed") + end + + if buffer then + -- we need to rip each message from this chunk + for s in buffer:gmatch("[^\r\n]+") do + local s_copy = s + s_copy = string.sub(s_copy,7) + s_copy = cjson.decode(s_copy) + + buf:put(s_copy + and s_copy.choices + and s_copy.choices + and s_copy.choices[1] + and s_copy.choices[1].delta + and s_copy.choices[1].delta.content + or "") + + table.insert(events, s) + end + end + until not buffer + + assert.equal(#events, 8) + assert.equal(buf:tostring(), "The answer to 1 + 1 is 2.") + end) + it("good stream request cohere", function() local httpc = http.new() diff --git a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua index 51d84a43992..cc64fc489f6 100644 --- a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua +++ b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua @@ -1,4 +1,4 @@ -local llm_class = require("kong.llm") +local llm = require("kong.llm") local helpers = require "spec.helpers" local cjson = require "cjson" local http_mock = require "spec.helpers.http_mock" @@ -224,10 +224,10 @@ describe(PLUGIN_NAME .. ": (unit)", function() for name, format_options in pairs(FORMATS) do describe(name .. " transformer tests, exact json response", function() it("transforms request based on LLM instructions", function() - local llm = llm_class:new(format_options, {}) - assert.truthy(llm) + local llmdriver = llm.new_driver(format_options, {}) + assert.truthy(llmdriver) - local result, err = llm:ai_introspect_body( + local result, err = llmdriver:ai_introspect_body( REQUEST_BODY, -- request body SYSTEM_PROMPT, -- conf.prompt {}, -- http opts @@ -246,10 +246,10 @@ describe(PLUGIN_NAME .. ": (unit)", function() describe("openai transformer tests, pattern matchers", function() it("transforms request based on LLM instructions, with json extraction pattern", function() - local llm = llm_class:new(OPENAI_NOT_JSON, {}) - assert.truthy(llm) + local llmdriver = llm.new_driver(OPENAI_NOT_JSON, {}) + assert.truthy(llmdriver) - local result, err = llm:ai_introspect_body( + local result, err = llmdriver:ai_introspect_body( REQUEST_BODY, -- request body SYSTEM_PROMPT, -- conf.prompt {}, -- http opts @@ -265,10 +265,10 @@ describe(PLUGIN_NAME .. 
": (unit)", function() end) it("transforms request based on LLM instructions, but fails to match pattern", function() - local llm = llm_class:new(OPENAI_NOT_JSON, {}) - assert.truthy(llm) + local llmdriver = llm.new_driver(OPENAI_NOT_JSON, {}) + assert.truthy(llmdriver) - local result, err = llm:ai_introspect_body( + local result, err = llmdriver:ai_introspect_body( REQUEST_BODY, -- request body SYSTEM_PROMPT, -- conf.prompt {}, -- http opts diff --git a/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua b/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua index d436ad53644..7f4e544ecc3 100644 --- a/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua +++ b/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua @@ -1,4 +1,4 @@ -local llm_class = require("kong.llm") +local llm = require("kong.llm") local helpers = require "spec.helpers" local cjson = require "cjson" local http_mock = require "spec.helpers.http_mock" @@ -90,10 +90,10 @@ describe(PLUGIN_NAME .. ": (unit)", function() describe("openai transformer tests, specific response", function() it("transforms request based on LLM instructions, with response transformation instructions format", function() - local llm = llm_class:new(OPENAI_INSTRUCTIONAL_RESPONSE, {}) - assert.truthy(llm) + local llmdriver = llm.new_driver(OPENAI_INSTRUCTIONAL_RESPONSE, {}) + assert.truthy(llmdriver) - local result, err = llm:ai_introspect_body( + local result, err = llmdriver:ai_introspect_body( REQUEST_BODY, -- request body SYSTEM_PROMPT, -- conf.prompt {}, -- http opts @@ -107,14 +107,14 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.same(EXPECTED_RESULT, table_result) -- parse in response string format - local headers, body, status, err = llm:parse_json_instructions(result) + local headers, body, status, err = llmdriver:parse_json_instructions(result) assert.is_nil(err) assert.same({ ["content-type"] = "application/xml" }, headers) assert.same(209, status) assert.same(EXPECTED_RESULT.body, body) -- parse in response table format - headers, body, status, err = llm:parse_json_instructions(table_result) + headers, body, status, err = llmdriver:parse_json_instructions(table_result) assert.is_nil(err) assert.same({ ["content-type"] = "application/xml" }, headers) assert.same(209, status) diff --git a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua index 800100c9a67..13be816735a 100644 --- a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua +++ b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua @@ -210,7 +210,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then server { server_name llm; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; location ~/instructions { @@ -237,7 +237,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/badrequest" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) } @@ -246,7 +246,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/internalservererror" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html")) @@ -357,7 +357,7 
@@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -379,7 +379,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = REQUEST_BODY, }) - + local body = assert.res_status(209 , r) assert.same(EXPECTED_RESULT.body, body) assert.same(r.headers["content-type"], "application/xml") @@ -393,7 +393,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = REQUEST_BODY, }) - + local body = assert.res_status(200 , r) local body_table, err = cjson.decode(body) assert.is_nil(err) @@ -431,7 +431,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = REQUEST_BODY, }) - + local body = assert.res_status(500 , r) local body_table, err = cjson.decode(body) assert.is_nil(err)
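
The new guard in the ai-proxy handler (exercised by the "tries to override configured model" integration test) only accepts a request-supplied model when the plugin has no model configured, or when the request value matches the configured name. Below is a minimal standalone sketch of that check, using illustrative names rather than the handler's actual locals:

-- sketch only: mirrors the guard added in _M:access, not the plugin code itself
local function check_model(conf_model_name, request_model)
  local model_name = conf_model_name

  -- copy the model name from the request only when none is configured
  if not model_name and type(request_model) == "string" then
    model_name = request_model
  end

  -- reject attempts to override the configured model
  if type(request_model) == "string" and request_model ~= ""
     and request_model ~= model_name then
    return nil, "cannot use own model - must be: " .. model_name
  end

  if not model_name then
    return nil, "model parameter not found in request, nor in gateway configuration"
  end

  return model_name
end

print(check_model("gpt-3.5-turbo", "gpt-3.5-turbo"))  --> gpt-3.5-turbo
print(check_model("gpt-3.5-turbo", "gpt-4"))          --> nil   cannot use own model - must be: gpt-3.5-turbo
print(check_model(nil, "gpt-4"))                      --> gpt-4 (no configured model, request value is used)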
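
Both transformer handlers now build a driver object via llm.new_driver and call ai_introspect_body / parse_json_instructions on that object instead of on the llm module itself. The following is a hedged, self-contained illustration of why a per-request instance matters; the constructor and method names are illustrative, not Kong's API:

local LLM = {}
LLM.__index = LLM

-- buggy pattern: configuration lands on the shared module table,
-- so every caller sees whoever configured it last
function LLM.new_shared(conf)
  LLM.conf = conf
  return LLM
end

-- fixed pattern: configuration lands on a fresh instance
function LLM.new_instance(conf)
  return setmetatable({ conf = conf }, LLM)
end

function LLM:model_name()
  return self.conf and self.conf.model or "?"
end

local a = LLM.new_shared({ model = "gpt-3.5-turbo" })
local b = LLM.new_shared({ model = "command-r" })
print(a:model_name(), b:model_name())  --> command-r   command-r   (state clobbered)

local c = LLM.new_instance({ model = "gpt-3.5-turbo" })
local d = LLM.new_instance({ model = "command-r" })
print(c:model_name(), d:model_name())  --> gpt-3.5-turbo   command-r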
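
The two new config-compat cases assert that realm is dropped from basic-auth config for data planes older than 3.6 and from key-auth config for data planes older than 3.7. A rough sketch of that kind of version-gated field stripping follows; the numeric version encoding and helper names are assumptions for illustration, not Kong's internals:

local REMOVED_FIELDS = {
  [3006000000] = { basic_auth = { "realm" } },  -- data planes older than 3.6.0
  [3007000000] = { key_auth   = { "realm" } },  -- data planes older than 3.7.0
}

local function strip_removed_fields(plugin_name, config, dp_version_num)
  for cutoff, plugins in pairs(REMOVED_FIELDS) do
    if dp_version_num < cutoff then
      for _, field in ipairs(plugins[plugin_name] or {}) do
        config[field] = nil  -- the older DP would not recognize this field
      end
    end
  end
  return config
end

-- a 3.5.0 data plane must not receive basic-auth realm
print(strip_removed_fields("basic_auth", { realm = "service" }, 3005000000).realm)  --> nil
-- a 3.6.0 data plane: key-auth realm is still too new, basic-auth realm is fine
print(strip_removed_fields("key_auth",   { realm = "test" },    3006000000).realm)  --> nil
print(strip_removed_fields("basic_auth", { realm = "test" },    3006000000).realm)  --> test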
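
The revised unit expectation (max_tokens 1024 and top_p 0.5 in the formatted output) encodes the new precedence rule: tuning values supplied in the plugin's model options override the ones in the caller's request body, while request-only fields such as temperature and extra provider parameters pass through untouched. A simplified merge along those lines, assuming plain tables rather than the plugin's real conversion path:

local function merge_tuning(request_body, conf_options)
  local merged = {}
  for k, v in pairs(request_body) do merged[k] = v end  -- start from the request
  for k, v in pairs(conf_options) do merged[k] = v end  -- plugin config wins
  return merged
end

local request = {
  max_tokens = 256, temperature = 0.1, top_p = 0.2,
  some_extra_param = "string_val", another_extra_param = 0.5,
}
local conf = { max_tokens = 1024, top_p = 0.5 }

local out = merge_tuning(request, conf)
print(out.max_tokens, out.top_p, out.temperature)  --> 1024   0.5   0.1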
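
The /openai/llm/v1/chat/partial mock above deliberately withholds the "\n\n" terminator so that one data: frame arrives split across two reads (chunks [3] and [4]). The sketch below shows the kind of buffering such a consumer needs: keep an accumulator and only emit frames once the terminator has been seen. It is a generic standalone illustration, not a copy of the plugin's parser:

local function sse_reader()
  local pending = ""          -- bytes carried over between reads
  local frames  = {}

  local function feed(chunk)
    pending = pending .. chunk
    -- SSE frames are terminated by a blank line
    while true do
      local head, rest = pending:match("^(.-)\n\n(.*)$")
      if not head then break end
      table.insert(frames, head)
      pending = rest
    end
  end

  return feed, frames
end

local feed, frames = sse_reader()
-- second frame is truncated mid-JSON, like chunks [3] and [4] in the mock
feed('data: {"choices":[{"delta":{"content":"The "}}]}\n\n' ..
     'data: {"choices":[{"delta":{"content":"answ')
feed('er"}}]}\n\n')

for _, f in ipairs(frames) do print(f) end
--> data: {"choices":[{"delta":{"content":"The "}}]}
--> data: {"choices":[{"delta":{"content":"answer"}}]}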