Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a more flexible credentials mechanism for chat_azure() #248

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

* `chat_openai()` should be less likely to timeout when not streaming chat results (#213).

* `chat_azure()` now has a `credentials` argument to make authentication against
Azure more flexible (#248, @atheriel).

# ellmer 0.1.0

* New `chat_vllm()` to chat with models served by vLLM (#140).
Expand Down
86 changes: 62 additions & 24 deletions R/provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ NULL
#' value of the `AZURE_OPENAI_ENDPOINT` envinronment variable.
#' @param deployment_id Deployment id for the model you want to use.
#' @param api_version The API version to use.
#' @param api_key The API key to use for authentication. You generally should
#' not supply this directly, but instead set the `AZURE_OPENAI_API_KEY` environment
#' variable.
#' @param token Azure token for authentication. This is typically not required for
#' Azure OpenAI API calls, but can be used if your setup requires it.
#' @param api_key An API key to use for authentication. You generally should not
#' supply this directly, but instead set the `AZURE_OPENAI_API_KEY`
#' environment variable.
#' @param token A literal Azure token to use for authentication.
#' @param credentials A list of authentication headers to pass into
#' [`httr2::req_headers()`], a function that returns them, or `NULL` to use
#' `token` or `api_key` to generate these headers instead. This is an escape
#' hatch that allows users to incorporate Azure credentials generated by other
#' packages into \pkg{ellmer}, or to manage the lifetime of credentials that
#' need to be refreshed.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @export
Expand All @@ -34,46 +39,59 @@ chat_azure <- function(endpoint = azure_endpoint(),
api_version = NULL,
system_prompt = NULL,
turns = NULL,
api_key = azure_key(),
api_key = NULL,
token = NULL,
credentials = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
check_string(endpoint)
check_string(deployment_id)
api_version <- set_default(api_version, "2024-06-01")
turns <- normalize_turns(turns, system_prompt)
check_exclusive(api_key, token, credentials, .require = FALSE)
check_string(api_key, allow_null = TRUE)
check_string(token, allow_null = TRUE)
echo <- check_echo(echo)

base_url <- paste0(endpoint, "/openai/deployments/", deployment_id)
if (is_list(credentials)) {
static_credentials <- force(credentials)
credentials <- function() static_credentials
}
check_function(credentials, allow_null = TRUE)
credentials <- credentials %||% default_azure_credentials(api_key, token)

provider <- ProviderAzure(
base_url = base_url,
endpoint = endpoint,
model = deployment_id,
deployment_id = deployment_id,
api_version = api_version,
token = token,
extra_args = api_args,
api_key = api_key
credentials = credentials,
extra_args = api_args
)
Chat$new(provider = provider, turns = turns, echo = echo)
}

ProviderAzure <- new_class(
"ProviderAzure",
parent = ProviderOpenAI,
constructor = function(endpoint, deployment_id, api_version, credentials,
extra_args = list()) {
new_object(
ProviderOpenAI(
base_url = paste0(endpoint, "/openai/deployments/", deployment_id),
model = deployment_id,
api_key = "",
extra_args = extra_args
),
api_version = api_version,
credentials = credentials
)
},
properties = list(
api_key = prop_string(),
token = prop_string(allow_null = TRUE),
endpoint = prop_string(),
credentials = class_function | NULL,
api_version = prop_string()
)
)

# https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#api-key
azure_key <- function() {
key_get("AZURE_OPENAI_API_KEY")
}

azure_endpoint <- function() {
key_get("AZURE_OPENAI_ENDPOINT")
}
Expand All @@ -89,10 +107,13 @@ method(chat_request, ProviderAzure) <- function(provider,
req <- request(provider@base_url)
req <- req_url_path_append(req, "/chat/completions")
req <- req_url_query(req, `api-version` = provider@api_version)
req <- req_headers(req, `api-key` = provider@api_key, .redact = "api-key")
if (!is.null(provider@token)) {
req <- req_auth_bearer_token(req, provider@token)
}
# Note: could use req_headers_redacted() here but it requires a very new
# httr2 version.
req <- req_headers(
req,
!!!provider@credentials(),
.redact = c("api-key", "Authorization")
)
req <- req_retry(req, max_tries = 2)
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)

Expand Down Expand Up @@ -127,3 +148,20 @@ method(chat_request, ProviderAzure) <- function(provider,

req
}

default_azure_credentials <- function(api_key = NULL, token = NULL) {
api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY")
if (nchar(api_key)) {
return(function() list(`api-key` = api_key))
}

if (!is.null(token)) {
return(function() list(Authorization = paste("Bearer", token)))
}

if (is_testing()) {
testthat::skip("no Azure credentials available")
}

cli::cli_abort("No Azure credentials are available.")
}
19 changes: 13 additions & 6 deletions man/chat_azure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions tests/testthat/_snaps/provider-azure.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Azure request headers are generated correctly

Code
req
Message
<httr2_request>
POST
https://ai-hwickhamai260967855527.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-06-01
Headers:
* api-key: '<REDACTED>'
Body: json encoded data
Policies:
* retry_max_tries: 2
* retry_on_failure: FALSE
* error_body: a function

---

Code
req
Message
<httr2_request>
POST
https://ai-hwickhamai260967855527.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-06-01
Headers:
* Authorization: '<REDACTED>'
Body: json encoded data
Policies:
* retry_max_tries: 2
* retry_on_failure: FALSE
* error_body: a function

29 changes: 29 additions & 0 deletions tests/testthat/test-provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,32 @@ test_that("can make simple request", {
expect_match(resp, "2")
expect_equal(chat$last_turn()@tokens, c(44, 1))
})

test_that("Azure request headers are generated correctly", {
turn <- Turn(
role = "user",
contents = list(ContentText("What is 1 + 1?"))
)
endpoint <- "https://ai-hwickhamai260967855527.openai.azure.com"
deployment_id <- "gpt-4o-mini"

# API key.
p <- ProviderAzure(
endpoint = endpoint,
deployment_id = deployment_id,
api_version = "2024-06-01",
credentials = default_azure_credentials("key")
)
req <- chat_request(p, FALSE, list(turn))
expect_snapshot(req)

# Token.
p <- ProviderAzure(
endpoint = endpoint,
deployment_id = deployment_id,
api_version = "2024-06-01",
credentials = default_azure_credentials("", "token")
)
req <- chat_request(p, FALSE, list(turn))
expect_snapshot(req)
})
Loading