diff --git a/sdk/identity/azure_identity/Cargo.toml b/sdk/identity/azure_identity/Cargo.toml index d5fa67ea5d..9ca365d6eb 100644 --- a/sdk/identity/azure_identity/Cargo.toml +++ b/sdk/identity/azure_identity/Cargo.toml @@ -29,6 +29,9 @@ url.workspace = true [target.'cfg(unix)'.dependencies] tz-rs = { workspace = true, optional = true } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +async-process.workspace = true + [dev-dependencies] azure_core_test.workspace = true azure_security_keyvault_secrets = { path = "../../keyvault/azure_security_keyvault_secrets" } diff --git a/sdk/identity/azure_identity/examples/interactive_credentials.rs b/sdk/identity/azure_identity/examples/interactive_credentials.rs new file mode 100644 index 0000000000..ed91977d10 --- /dev/null +++ b/sdk/identity/azure_identity/examples/interactive_credentials.rs @@ -0,0 +1,40 @@ +use azure_identity::interactive_credential::interactive_browser_credential::InteractiveBrowserCredential; +use oauth2::TokenResponse; +use reqwest::Client; +use std::error::Error; +use url::Url; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let test_subscription_id = + std::env::var("AZURE_SUBSCRIPTION_ID").expect("AZURE_SUBSCRIPTION_ID required"); + let test_tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID required"); + + let _ = run_app_inter(test_subscription_id, test_tenant_id).await?; + Ok(()) +} + +async fn run_app_inter(subscription_id: String, tenant_id: String) -> Result<(), Box> { + let interactive_credentials = InteractiveBrowserCredential::new(None, Some(tenant_id), None)?; + + let token_response = interactive_credentials + .get_token(Some(&["https://management.azure.com/.default"])) + .await?; + + let access_token_secret = token_response.access_token().secret(); + + let url = Url::parse(&format!( + "https://management.azure.com/subscriptions/{subscription_id}/providers/Microsoft.Storage/storageAccounts?api-version=2019-06-01" + ))?; + + let resp = Client::new() + .get(url) + .header("Authorization", format!("Bearer {}", access_token_secret)) + .send() + .await? + .text() + .await?; + + println!("Res interactive: {resp}"); + Ok(()) +} diff --git a/sdk/identity/azure_identity/src/interactive_credential/interactive_browser_credential.rs b/sdk/identity/azure_identity/src/interactive_credential/interactive_browser_credential.rs new file mode 100644 index 0000000000..ff686501a9 --- /dev/null +++ b/sdk/identity/azure_identity/src/interactive_credential/interactive_browser_credential.rs @@ -0,0 +1,116 @@ +use super::internal_server::*; +use crate::authorization_code_flow; +use azure_core::{error::ErrorKind, http::new_http_client, http::Url, Error}; +use oauth2::{ + basic::BasicTokenType, AuthorizationCode, ClientId, EmptyExtraTokenFields, + StandardTokenResponse, +}; +use std::{str::FromStr, sync::Arc}; + +/// Default OAuth scopes used when none are provided. +#[allow(dead_code)] +const DEFAULT_SCOPE_ARR: [&str; 3] = ["openid", "offline_access", "profile"]; +/// Default client ID for interactive browser authentication. +#[allow(dead_code)] +const DEFAULT_DEVELOPER_SIGNON_CLIENT_ID: &str = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; +/// Default tenant ID used when none is specified. +#[allow(dead_code)] +const DEFAULT_ORGANIZATIONS_TENANT_ID: &str = "organizations"; + +/// Provides interactive browser-based authentication. +#[derive(Clone)] +pub struct InteractiveBrowserCredential { + /// Client ID of the application. + pub(crate) client_id: ClientId, + /// Tenant ID for the authentication request. + pub(crate) tenant_id: String, + /// Redirect URI where the authentication response is sent. + pub(crate) redirect_url: Url, +} + +impl InteractiveBrowserCredential { + /// Creates a new `InteractiveBrowserCredential` instance with optional parameters. + pub fn new( + client_id: Option, + tenant_id: Option, + redirect_url: Option, + ) -> azure_core::Result> { + let client_id = client_id + .unwrap_or_else(|| ClientId::new(DEFAULT_DEVELOPER_SIGNON_CLIENT_ID.to_owned())); + + let tenant_id = tenant_id.unwrap_or_else(|| DEFAULT_ORGANIZATIONS_TENANT_ID.to_owned()); + + let redirect_url = redirect_url.unwrap_or_else(|| { + Url::from_str(&format!("http://localhost:{}", LOCAL_SERVER_PORT)) + .expect("Failed to parse redirect URL") + }); + + Ok(Arc::new(Self { + client_id, + tenant_id, + redirect_url, + })) + } + + /// Starts the interactive browser authentication flow and returns an access token. + /// + /// If no scopes are provided, default scopes will be used. + #[allow(dead_code)] + pub async fn get_token( + &self, + scopes: Option<&[&str]>, + ) -> azure_core::Result> { + let scopes = scopes.unwrap_or(&DEFAULT_SCOPE_ARR); + + let authorization_code_flow = authorization_code_flow::authorize( + self.client_id.clone(), + None, + &self.tenant_id, + self.redirect_url.clone(), + scopes, + ); + + let auth_code = open_url(authorization_code_flow.authorize_url.as_ref()).await; + + match auth_code { + Some(code) => { + authorization_code_flow + .exchange(new_http_client(), AuthorizationCode::new(code)) + .await + } + None => Err(Error::message( + ErrorKind::Other, + "Failed to retrieve authorization code.", + )), + } + } +} +#[cfg(test)] +mod tests { + use super::*; + use tracing::debug; + use tracing::Level; + use tracing_subscriber; + static INIT: std::sync::Once = std::sync::Once::new(); + + fn init_tracing() { + INIT.call_once(|| { + tracing_subscriber::fmt() + .with_max_level(Level::DEBUG) + .init(); + }); + } + + #[tokio::test] + async fn interactive_auth_flow_should_return_token() { + init_tracing(); + debug!("Starting interactive authentication test"); + + let credential = InteractiveBrowserCredential::new(None, None, None) + .expect("Failed to create credential"); + + let token_response = credential.get_token(None).await; + debug!("Authentication result: {:#?}", token_response); + assert!(token_response.is_ok()); + } +} diff --git a/sdk/identity/azure_identity/src/interactive_credential/internal_server.rs b/sdk/identity/azure_identity/src/interactive_credential/internal_server.rs new file mode 100644 index 0000000000..dcc8ad08c1 --- /dev/null +++ b/sdk/identity/azure_identity/src/interactive_credential/internal_server.rs @@ -0,0 +1,203 @@ +use std::io::{self, BufRead, BufReader, Write}; +use std::net::{Shutdown, TcpListener, TcpStream}; +use std::time::Duration; +use tracing::{error, info}; + +///The port where the local server is listening on the auth_code +#[allow(dead_code)] +pub const LOCAL_SERVER_PORT: u16 = 47828; + +/// Opens the given URL in the default system browser and starts a local web server +/// to receive the authorization code. +#[allow(dead_code)] +#[cfg(target_os = "windows")] +pub async fn open_url(url: &str) -> Option { + use async_process::Command; + let spawned = Command::new("cmd").args(["/C", "explorer", url]).spawn(); + handle_browser_command(spawned) +} + +/// Opens the given URL in the default system browser and starts a local web server +/// to receive the authorization code. +#[allow(dead_code)] +#[cfg(target_os = "macos")] +pub async fn open_url(url: &str) -> Option { + use async_process::Command; + let spawned = Command::new("open").arg(url).spawn(); + handle_browser_command(spawned) +} + +/// Opens the given URL in the default system browser and starts a local web server +/// to receive the authorization code. +#[allow(dead_code)] +#[cfg(target_os = "linux")] +pub async fn open_url(url: &str) -> Option { + use async_process::Command; + + if let Some(command) = find_linux_browser_command().await { + let spawned = Command::new(command).arg(url).spawn(); + return handle_browser_command(spawned); + } + + info!("Open the following link manually in your browser: {url}"); + None +} + +/// Method to check if the command to open the link in a browser is available on the computer +/// exists. +#[allow(dead_code)] +#[cfg(target_os = "linux")] +async fn is_command_available(cmd: &str) -> bool { + use async_process::Command; + Command::new("which") + .arg(cmd) + .output() + .await + .map(|o| !o.stdout.is_empty()) + .unwrap_or(false) +} + +/// Method with all the commands which could open the browser to call the authorization url +/// If there is no command installed or available on the system, it returns a 'None' and the link +/// will be logged +#[allow(dead_code)] +#[cfg(target_os = "linux")] +async fn find_linux_browser_command() -> Option { + let candidates = [ + "xdg-open", + "gnome-open", + "kfmclient", + "microsoft-edge", + "wslview", + ]; + for cmd in candidates.iter() { + if is_command_available(cmd).await { + return Some(cmd.to_string()); + } + } + None +} + +/// starting the browser if the browser could be started, then the webserver should be started to +/// get the auth code +#[allow(dead_code)] +fn handle_browser_command(result: Result) -> Option { + match result { + Ok(_) => start_webserver(), + Err(e) => { + error!("Failed to start browser command: {e}"); + None + } + } +} + +/// Starts the webserver on the `http://localhost`. Returns None, if the server could not have +/// started +#[allow(dead_code)] +/// Starts a simple HTTP server on localhost to receive the auth code. +fn start_webserver() -> Option { + TcpListener::bind(("127.0.0.1", LOCAL_SERVER_PORT)) + .ok() + .and_then(handle_tcp_connection) +} + +fn handle_tcp_connection(listener: TcpListener) -> Option { + listener + .incoming() + .take(1) + .next()? + .ok() + .and_then(handle_client) +} +/// Main method to handle the incomming traffic. +/// After a 10s timeout the stream will be closed +/// if the stream could be opened, we read the whole request and try to extract the auth_code +/// Returns also the html code to show if it worked +#[allow(dead_code)] +fn handle_client(mut stream: TcpStream) -> Option { + stream + .set_read_timeout(Some(Duration::from_secs(10))) + .ok()?; + + let buf_reader = BufReader::new(&stream); + let mut request_lines = vec![]; + for line in buf_reader.lines().map_while(Result::ok) { + if line.is_empty() { + break; + } + request_lines.push(line); + } + + let request = request_lines.join("\n"); + + let auth_code = extract_auth_code(&request); + let response_body = r#" +Auth Complete +

Authentication complete. You may close this tab.

+"#; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + + stream.write_all(response.as_bytes()).ok()?; + stream.flush().ok()?; + stream.shutdown(Shutdown::Both).ok()?; + + auth_code +} + +/// Extracts the `code` query parameter from the request. +#[allow(dead_code)] +fn extract_auth_code(request: &str) -> Option { + let code_start = request.rfind("code=")? + 5; + let rest = &request[code_start..]; + let end = rest.find('&').unwrap_or(rest.len()); + Some(rest[..end].to_string()) +} + +#[cfg(test)] +mod test_internal_server { + use super::*; + use tracing::debug; + use tracing::Level; + use tracing_subscriber::FmtSubscriber; + fn init_logger() { + let subscriber = FmtSubscriber::builder() + .with_max_level(Level::DEBUG) + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); + } + + #[tokio::test] + async fn test_valid_command() { + init_logger(); + assert!(is_command_available("ls").await); + } + + #[tokio::test] + async fn test_invalid_command() { + init_logger(); + assert!(!is_command_available("non_existing_command_foo").await); + } + + #[test] + fn test_extract_code_param() { + let url = "GET /?code=abc123&state=xyz"; + assert_eq!(extract_auth_code(url).unwrap(), "abc123"); + } + + #[test] + fn test_extract_code_at_end() { + let url = "GET /?state=xyz&code=abc123"; + assert_eq!(extract_auth_code(url).unwrap(), "abc123"); + } + + #[test] + fn test_extract_code_missing() { + let url = "GET /?state=only"; + assert!(extract_auth_code(url).is_none()); + } +} diff --git a/sdk/identity/azure_identity/src/interactive_credential/mod.rs b/sdk/identity/azure_identity/src/interactive_credential/mod.rs new file mode 100644 index 0000000000..54d2039890 --- /dev/null +++ b/sdk/identity/azure_identity/src/interactive_credential/mod.rs @@ -0,0 +1,3 @@ +mod internal_server; + +pub mod interactive_browser_credential; diff --git a/sdk/identity/azure_identity/src/lib.rs b/sdk/identity/azure_identity/src/lib.rs index dbfac517cf..46431f87b5 100644 --- a/sdk/identity/azure_identity/src/lib.rs +++ b/sdk/identity/azure_identity/src/lib.rs @@ -10,6 +10,7 @@ mod chained_token_credential; mod credentials; mod env; mod federated_credentials_flow; +pub mod interactive_credential; mod oauth2_http_client; mod refresh_token; mod timeout;