diff --git a/NEWS.md b/NEWS.md index 406fa99..eebabc0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,9 @@ * `chat_openai()` should be less likely to timeout when not streaming chat results (#213). +* `chat_azure()` now has a `credentials` argument to make authentication against + Azure more flexible (#248, @atheriel). + # ellmer 0.1.0 * New `chat_vllm()` to chat with models served by vLLM (#140). diff --git a/R/provider-azure.R b/R/provider-azure.R index 7b92b82..77d784b 100644 --- a/R/provider-azure.R +++ b/R/provider-azure.R @@ -16,11 +16,16 @@ NULL #' value of the `AZURE_OPENAI_ENDPOINT` envinronment variable. #' @param deployment_id Deployment id for the model you want to use. #' @param api_version The API version to use. -#' @param api_key The API key to use for authentication. You generally should -#' not supply this directly, but instead set the `AZURE_OPENAI_API_KEY` environment -#' variable. -#' @param token Azure token for authentication. This is typically not required for -#' Azure OpenAI API calls, but can be used if your setup requires it. +#' @param api_key An API key to use for authentication. You generally should not +#' supply this directly, but instead set the `AZURE_OPENAI_API_KEY` +#' environment variable. +#' @param token A literal Azure token to use for authentication. +#' @param credentials A list of authentication headers to pass into +#' [`httr2::req_headers()`], a function that returns them, or `NULL` to use +#' `token` or `api_key` to generate these headers instead. This is an escape +#' hatch that allows users to incorporate Azure credentials generated by other +#' packages into \pkg{ellmer}, or to manage the lifetime of credentials that +#' need to be refreshed. #' @inheritParams chat_openai #' @inherit chat_openai return #' @export @@ -34,26 +39,32 @@ chat_azure <- function(endpoint = azure_endpoint(), api_version = NULL, system_prompt = NULL, turns = NULL, - api_key = azure_key(), + api_key = NULL, token = NULL, + credentials = NULL, api_args = list(), echo = c("none", "text", "all")) { check_string(endpoint) check_string(deployment_id) api_version <- set_default(api_version, "2024-06-01") turns <- normalize_turns(turns, system_prompt) + check_exclusive(api_key, token, credentials, .require = FALSE) + check_string(api_key, allow_null = TRUE) + check_string(token, allow_null = TRUE) echo <- check_echo(echo) - - base_url <- paste0(endpoint, "/openai/deployments/", deployment_id) + if (is_list(credentials)) { + static_credentials <- force(credentials) + credentials <- function() static_credentials + } + check_function(credentials, allow_null = TRUE) + credentials <- credentials %||% default_azure_credentials(api_key, token) provider <- ProviderAzure( - base_url = base_url, endpoint = endpoint, - model = deployment_id, + deployment_id = deployment_id, api_version = api_version, - token = token, - extra_args = api_args, - api_key = api_key + credentials = credentials, + extra_args = api_args ) Chat$new(provider = provider, turns = turns, echo = echo) } @@ -61,19 +72,26 @@ chat_azure <- function(endpoint = azure_endpoint(), ProviderAzure <- new_class( "ProviderAzure", parent = ProviderOpenAI, + constructor = function(endpoint, deployment_id, api_version, credentials, + extra_args = list()) { + new_object( + ProviderOpenAI( + base_url = paste0(endpoint, "/openai/deployments/", deployment_id), + model = deployment_id, + api_key = "", + extra_args = extra_args + ), + api_version = api_version, + credentials = credentials + ) + }, properties = list( - api_key = prop_string(), - token = prop_string(allow_null = TRUE), - endpoint = prop_string(), + credentials = class_function | NULL, api_version = prop_string() ) ) # https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#api-key -azure_key <- function() { - key_get("AZURE_OPENAI_API_KEY") -} - azure_endpoint <- function() { key_get("AZURE_OPENAI_ENDPOINT") } @@ -89,10 +107,13 @@ method(chat_request, ProviderAzure) <- function(provider, req <- request(provider@base_url) req <- req_url_path_append(req, "/chat/completions") req <- req_url_query(req, `api-version` = provider@api_version) - req <- req_headers(req, `api-key` = provider@api_key, .redact = "api-key") - if (!is.null(provider@token)) { - req <- req_auth_bearer_token(req, provider@token) - } + # Note: could use req_headers_redacted() here but it requires a very new + # httr2 version. + req <- req_headers( + req, + !!!provider@credentials(), + .redact = c("api-key", "Authorization") + ) req <- req_retry(req, max_tries = 2) req <- req_error(req, body = function(resp) resp_body_json(resp)$message) @@ -127,3 +148,20 @@ method(chat_request, ProviderAzure) <- function(provider, req } + +default_azure_credentials <- function(api_key = NULL, token = NULL) { + api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY") + if (nchar(api_key)) { + return(function() list(`api-key` = api_key)) + } + + if (!is.null(token)) { + return(function() list(Authorization = paste("Bearer", token))) + } + + if (is_testing()) { + testthat::skip("no Azure credentials available") + } + + cli::cli_abort("No Azure credentials are available.") +} diff --git a/man/chat_azure.Rd b/man/chat_azure.Rd index 8e1b842..ce228f6 100644 --- a/man/chat_azure.Rd +++ b/man/chat_azure.Rd @@ -10,8 +10,9 @@ chat_azure( api_version = NULL, system_prompt = NULL, turns = NULL, - api_key = azure_key(), + api_key = NULL, token = NULL, + credentials = NULL, api_args = list(), echo = c("none", "text", "all") ) @@ -31,12 +32,18 @@ value of the \code{AZURE_OPENAI_ENDPOINT} envinronment variable.} previous conversation). If not provided, the conversation begins from scratch.} -\item{api_key}{The API key to use for authentication. You generally should -not supply this directly, but instead set the \code{AZURE_OPENAI_API_KEY} environment -variable.} +\item{api_key}{An API key to use for authentication. You generally should not +supply this directly, but instead set the \code{AZURE_OPENAI_API_KEY} +environment variable.} -\item{token}{Azure token for authentication. This is typically not required for -Azure OpenAI API calls, but can be used if your setup requires it.} +\item{token}{A literal Azure token to use for authentication.} + +\item{credentials}{A list of authentication headers to pass into +\code{\link[httr2:req_headers]{httr2::req_headers()}}, a function that returns them, or \code{NULL} to use +\code{token} or \code{api_key} to generate these headers instead. This is an escape +hatch that allows users to incorporate Azure credentials generated by other +packages into \pkg{ellmer}, or to manage the lifetime of credentials that +need to be refreshed.} \item{api_args}{Named list of arbitrary extra arguments appended to the body of every chat API call.} diff --git a/tests/testthat/_snaps/provider-azure.md b/tests/testthat/_snaps/provider-azure.md new file mode 100644 index 0000000..7a12aae --- /dev/null +++ b/tests/testthat/_snaps/provider-azure.md @@ -0,0 +1,32 @@ +# Azure request headers are generated correctly + + Code + req + Message + + POST + https://ai-hwickhamai260967855527.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-06-01 + Headers: + * api-key: '' + Body: json encoded data + Policies: + * retry_max_tries: 2 + * retry_on_failure: FALSE + * error_body: a function + +--- + + Code + req + Message + + POST + https://ai-hwickhamai260967855527.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-06-01 + Headers: + * Authorization: '' + Body: json encoded data + Policies: + * retry_max_tries: 2 + * retry_on_failure: FALSE + * error_body: a function + diff --git a/tests/testthat/test-provider-azure.R b/tests/testthat/test-provider-azure.R index 5998821..ed04b5a 100644 --- a/tests/testthat/test-provider-azure.R +++ b/tests/testthat/test-provider-azure.R @@ -12,3 +12,32 @@ test_that("can make simple request", { expect_match(resp, "2") expect_equal(chat$last_turn()@tokens, c(44, 1)) }) + +test_that("Azure request headers are generated correctly", { + turn <- Turn( + role = "user", + contents = list(ContentText("What is 1 + 1?")) + ) + endpoint <- "https://ai-hwickhamai260967855527.openai.azure.com" + deployment_id <- "gpt-4o-mini" + + # API key. + p <- ProviderAzure( + endpoint = endpoint, + deployment_id = deployment_id, + api_version = "2024-06-01", + credentials = default_azure_credentials("key") + ) + req <- chat_request(p, FALSE, list(turn)) + expect_snapshot(req) + + # Token. + p <- ProviderAzure( + endpoint = endpoint, + deployment_id = deployment_id, + api_version = "2024-06-01", + credentials = default_azure_credentials("", "token") + ) + req <- chat_request(p, FALSE, list(turn)) + expect_snapshot(req) +})