diff --git a/.github/workflows/wasm.yml b/.github/workflows/wasm.yml index 08420c7f..405baf57 100644 --- a/.github/workflows/wasm.yml +++ b/.github/workflows/wasm.yml @@ -20,6 +20,10 @@ jobs: - uses: actions/checkout@v4 with: path: crates + - name: Install Rust 1.75.0 + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: 1.75.0 # We use a synthetic crate to ensure no dev-dependencies are enabled, which can # be incompatible with some of these targets. - name: Create synthetic crate for testing diff --git a/Cargo.lock b/Cargo.lock index 780389bc..4dae6761 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -278,8 +278,10 @@ dependencies = [ "atrium-xrpc", "base64", "chrono", + "dashmap", "ecdsa", "elliptic-curve", + "futures", "hickory-resolver", "jose-jwa", "jose-jwk", @@ -1995,9 +1997,9 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" -version = "0.10.70" +version = "0.10.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -2027,9 +2029,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.105" +version = "0.9.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" dependencies = [ "cc", "libc", diff --git a/atrium-api/src/agent/atp_agent.rs b/atrium-api/src/agent/atp_agent.rs index 3d29f3ea..486c849d 100644 --- a/atrium-api/src/agent/atp_agent.rs +++ b/atrium-api/src/agent/atp_agent.rs @@ -324,6 +324,7 @@ mod tests { &self, request: Request<Vec<u8>>, ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> { + // tick tokio time #[cfg(not(target_arch = "wasm32"))] tokio::time::sleep(std::time::Duration::from_micros(10)).await; diff --git a/atrium-api/src/agent/store.rs b/atrium-api/src/agent/store.rs deleted file mode 100644 index 22bdcb37..00000000 --- a/atrium-api/src/agent/store.rs +++ /dev/null @@ -1,16 +0,0 @@ -mod memory; - -use std::future::Future; - -pub use self::memory::MemorySessionStore; -pub(crate) use super::Session; - -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] -pub trait SessionStore { - #[must_use] - fn get_session(&self) -> impl Future<Output = Option<Session>>; - #[must_use] - fn set_session(&self, session: Session) -> impl Future<Output = ()>; - #[must_use] - fn clear_session(&self) -> impl Future<Output = ()>; -} diff --git a/atrium-api/src/agent/store/memory.rs b/atrium-api/src/agent/store/memory.rs deleted file mode 100644 index 05eedaaf..00000000 --- a/atrium-api/src/agent/store/memory.rs +++ /dev/null @@ -1,20 +0,0 @@ -use super::{Session, SessionStore}; -use std::sync::Arc; -use tokio::sync::RwLock; - -#[derive(Default, Clone)] -pub struct MemorySessionStore { - session: Arc<RwLock<Option<Session>>>, -} - -impl SessionStore for MemorySessionStore { - async fn get_session(&self) -> Option<Session> { - self.session.read().await.clone() - } - async fn set_session(&self, session: Session) { - self.session.write().await.replace(session); - } - async fn clear_session(&self) { - self.session.write().await.take(); - } -} diff --git a/atrium-common/src/store/memory.rs b/atrium-common/src/store/memory.rs index dc81fd7c..b792bf4d 100644 --- a/atrium-common/src/store/memory.rs +++ b/atrium-common/src/store/memory.rs @@ -2,14 +2,14 @@ use super::Store; use std::collections::HashMap; use std::fmt::Debug; use std::hash::Hash; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use thiserror::Error; +use tokio::sync::Mutex; #[derive(Error, Debug)] #[error("memory store error")] pub struct Error; -// TODO: LRU cache? #[derive(Clone)] pub struct MemoryStore<K, V> { store: Arc<Mutex<HashMap<K, V>>>, @@ -29,18 +29,18 @@ where type Error = Error; async fn get(&self, key: &K) -> Result<Option<V>, Self::Error> { - Ok(self.store.lock().unwrap().get(key).cloned()) + Ok(self.store.lock().await.get(key).cloned()) } async fn set(&self, key: K, value: V) -> Result<(), Self::Error> { - self.store.lock().unwrap().insert(key, value); + self.store.lock().await.insert(key, value); Ok(()) } async fn del(&self, key: &K) -> Result<(), Self::Error> { - self.store.lock().unwrap().remove(key); + self.store.lock().await.remove(key); Ok(()) } async fn clear(&self) -> Result<(), Self::Error> { - self.store.lock().unwrap().clear(); + self.store.lock().await.clear(); Ok(()) } } diff --git a/atrium-identity/src/did.rs b/atrium-identity/src/did.rs index 0b731cb1..9e873904 100644 --- a/atrium-identity/src/did.rs +++ b/atrium-identity/src/did.rs @@ -2,10 +2,9 @@ mod common_resolver; mod plc_resolver; mod web_resolver; -use crate::Error; - pub use self::common_resolver::{CommonDidResolver, CommonDidResolverConfig}; pub use self::plc_resolver::DEFAULT_PLC_DIRECTORY_URL; +use crate::Error; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; use atrium_common::resolver::Resolver; diff --git a/atrium-identity/src/identity_resolver.rs b/atrium-identity/src/identity_resolver.rs index a70e1856..2c27e825 100644 --- a/atrium-identity/src/identity_resolver.rs +++ b/atrium-identity/src/identity_resolver.rs @@ -1,6 +1,8 @@ use crate::error::{Error, Result}; -use crate::{did::DidResolver, handle::HandleResolver}; -use atrium_api::types::string::AtIdentifier; +use atrium_api::{ + did_doc::DidDocument, + types::string::{AtIdentifier, Did, Handle}, +}; use atrium_common::resolver::Resolver; use serde::{Deserialize, Serialize}; @@ -29,8 +31,8 @@ impl<D, H> IdentityResolver<D, H> { impl<D, H> Resolver for IdentityResolver<D, H> where - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync, { type Input = str; type Output = ResolvedIdentity; diff --git a/atrium-oauth/oauth-client/Cargo.toml b/atrium-oauth/oauth-client/Cargo.toml index 8920ccfc..f698f217 100644 --- a/atrium-oauth/oauth-client/Cargo.toml +++ b/atrium-oauth/oauth-client/Cargo.toml @@ -14,12 +14,13 @@ keywords = ["atproto", "bluesky", "oauth"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -atrium-api = { workspace = true, default-features = false } +atrium-api = { workspace = true, features = ["agent"] } atrium-common.workspace = true atrium-identity.workspace = true atrium-xrpc.workspace = true base64.workspace = true chrono.workspace = true +dashmap.workspace = true ecdsa = { workspace = true, features = ["signing"] } elliptic-curve.workspace = true jose-jwa.workspace = true @@ -32,11 +33,14 @@ serde_html_form.workspace = true serde_json.workspace = true sha2.workspace = true thiserror.workspace = true +tokio = { workspace = true, features = ["sync"] } trait-variant.workspace = true [dev-dependencies] +atrium-api = { workspace = true, features = ["bluesky"] } +futures.workspace = true hickory-resolver.workspace = true -p256 = { workspace = true, features = ["pem"] } +p256 = { workspace = true, features = ["pem", "std"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [features] diff --git a/atrium-oauth/oauth-client/examples/generate_key.rs b/atrium-oauth/oauth-client/examples/generate_key.rs new file mode 100644 index 00000000..08228e5c --- /dev/null +++ b/atrium-oauth/oauth-client/examples/generate_key.rs @@ -0,0 +1,28 @@ +use elliptic_curve::pkcs8::EncodePrivateKey; +use elliptic_curve::SecretKey; +use jose_jwa::{Algorithm, Signing}; +use jose_jwk::{Class, Jwk, JwkSet, Key, Parameters}; +use p256::NistP256; +use rand::rngs::ThreadRng; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let secret_key = SecretKey::<NistP256>::random(&mut ThreadRng::default()); + let key = Key::from(&secret_key.public_key().into()); + let jwks = JwkSet { + keys: vec![Jwk { + key, + prm: Parameters { + alg: Some(Algorithm::Signing(Signing::Es256)), + kid: Some(String::from("kid01")), + cls: Some(Class::Signing), + ..Default::default() + }, + }], + }; + println!("SECRET KEY:"); + println!("{}", secret_key.to_pkcs8_pem(Default::default())?.as_str()); + + println!("JWKS:"); + println!("{}", serde_json::to_string_pretty(&jwks)?); + Ok(()) +} diff --git a/atrium-oauth/oauth-client/examples/main.rs b/atrium-oauth/oauth-client/examples/main.rs index ee211fc4..9ee766e4 100644 --- a/atrium-oauth/oauth-client/examples/main.rs +++ b/atrium-oauth/oauth-client/examples/main.rs @@ -1,5 +1,7 @@ +use atrium_api::agent::Agent; use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver}; +use atrium_oauth_client::store::session::MemorySessionStore; use atrium_oauth_client::store::state::MemoryStateStore; use atrium_oauth_client::{ AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient, @@ -57,6 +59,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { protected_resource_metadata: Default::default(), }, state_store: MemoryStateStore::default(), + session_store: MemorySessionStore::default(), }; let client = OAuthClient::new(config)?; println!( @@ -76,7 +79,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { ); // Click the URL and sign in, - // then copy and paste the URL like “http://127.0.0.1/?iss=...&code=...” after it is redirected. + // then copy and paste the URL like “http://127.0.0.1/callback?iss=...&code=...” after it is redirected. print!("Redirected url: "); stdout().lock().flush()?; @@ -85,7 +88,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let uri = url.trim().parse::<Uri>()?; let params = serde_html_form::from_str(uri.query().unwrap())?; - println!("{}", serde_json::to_string_pretty(&client.callback(params).await?)?); - + let (session, _) = client.callback(params).await?; + let agent = Agent::new(session); + let output = agent + .api + .app + .bsky + .feed + .get_timeline( + atrium_api::app::bsky::feed::get_timeline::ParametersData { + algorithm: None, + cursor: None, + limit: 3.try_into().ok(), + } + .into(), + ) + .await?; + for feed in &output.feed { + println!("{feed:?}"); + } Ok(()) } diff --git a/atrium-oauth/oauth-client/src/error.rs b/atrium-oauth/oauth-client/src/error.rs index 16f87001..082b6227 100644 --- a/atrium-oauth/oauth-client/src/error.rs +++ b/atrium-oauth/oauth-client/src/error.rs @@ -5,17 +5,25 @@ pub enum Error { #[error(transparent)] ClientMetadata(#[from] crate::atproto::Error), #[error(transparent)] - Keyset(#[from] crate::keyset::Error), + Dpop(#[from] crate::http_client::dpop::Error), #[error(transparent)] - Identity(#[from] atrium_identity::Error), + Keyset(#[from] crate::keyset::Error), #[error(transparent)] ServerAgent(#[from] crate::server_agent::Error), + #[error(transparent)] + OAuthSession(#[from] crate::oauth_session::Error), + #[error(transparent)] + SessionRegistry(#[from] crate::store::session_registry::Error), + #[error(transparent)] + Identity(#[from] atrium_identity::Error), #[error("authorize error: {0}")] Authorize(String), #[error("callback error: {0}")] Callback(String), - #[error("state store error: {0:?}")] + #[error("state store error: {0}")] StateStore(Box<dyn std::error::Error + Send + Sync + 'static>), + #[error("session store error: {0}")] + SessionStore(Box<dyn std::error::Error + Send + Sync + 'static>), } pub type Result<T> = core::result::Result<T, Error>; diff --git a/atrium-oauth/oauth-client/src/http_client/dpop.rs b/atrium-oauth/oauth-client/src/http_client/dpop.rs index b92fd621..15570b10 100644 --- a/atrium-oauth/oauth-client/src/http_client/dpop.rs +++ b/atrium-oauth/oauth-client/src/http_client/dpop.rs @@ -1,17 +1,21 @@ -use crate::jose::create_signed_jwt; -use crate::jose::jws::RegisteredHeader; -use crate::jose::jwt::{Claims, PublicClaims, RegisteredClaims}; -use crate::store::memory::MemorySimpleStore; -use crate::store::SimpleStore; -use atrium_xrpc::http::{Request, Response}; -use atrium_xrpc::HttpClient; -use base64::engine::general_purpose::URL_SAFE_NO_PAD; -use base64::Engine; +use crate::jose::{ + create_signed_jwt, + jws::RegisteredHeader, + jwt::{Claims, PublicClaims, RegisteredClaims}, +}; +use atrium_common::store::{memory::MemoryStore, Store}; +use atrium_xrpc::{ + http::{Request, Response}, + HttpClient, +}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use chrono::Utc; use jose_jwa::{Algorithm, Signing}; use jose_jwk::{crypto, EcCurves, Jwk, Key}; -use rand::rngs::SmallRng; -use rand::{RngCore, SeedableRng}; +use rand::{ + rngs::SmallRng, + {RngCore, SeedableRng}, +}; use serde::Deserialize; use sha2::{Digest, Sha256}; use std::sync::Arc; @@ -36,9 +40,9 @@ pub enum Error { type Result<T> = core::result::Result<T, Error>; -pub struct DpopClient<T, S = MemorySimpleStore<String, String>> +pub struct DpopClient<T, S = MemoryStore<String, String>> where - S: SimpleStore<String, String>, + S: Store<String, String>, { inner: Arc<T>, pub(crate) key: Key, @@ -65,14 +69,14 @@ impl<T> DpopClient<T> { return Err(Error::UnsupportedKey); } } - let nonces = MemorySimpleStore::<String, String>::default(); + let nonces = MemoryStore::<String, String>::default(); Ok(Self { inner: http_client, key, nonces, is_auth_server }) } } impl<T, S> DpopClient<T, S> where - S: SimpleStore<String, String>, + S: Store<String, String>, { fn build_proof( &self, @@ -135,7 +139,8 @@ where impl<T, S> HttpClient for DpopClient<T, S> where T: HttpClient + Send + Sync + 'static, - S: SimpleStore<String, String> + Send + Sync + 'static, + S: Store<String, String> + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, { async fn send_http( &self, @@ -150,7 +155,7 @@ where let ath = request .headers() .get("Authorization") - .filter(|v| v.to_str().map_or(false, |s| s.starts_with("DPoP "))) + .filter(|v| v.to_str().is_ok_and(|s| s.starts_with("DPoP "))) .map(|auth| URL_SAFE_NO_PAD.encode(Sha256::digest(&auth.as_bytes()[5..]))); let init_nonce = self.nonces.get(&nonce_key).await?; @@ -182,3 +187,14 @@ where Ok(response) } } + +impl<T> Clone for DpopClient<T> { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + key: self.key.clone(), + nonces: self.nonces.clone(), + is_auth_server: self.is_auth_server, + } + } +} diff --git a/atrium-oauth/oauth-client/src/keyset.rs b/atrium-oauth/oauth-client/src/keyset.rs index b6728f9e..c5aee2bf 100644 --- a/atrium-oauth/oauth-client/src/keyset.rs +++ b/atrium-oauth/oauth-client/src/keyset.rs @@ -57,7 +57,7 @@ impl Keyset { .0 .iter() .filter_map(|key| { - if key.prm.cls.map_or(false, |c| c != cls) { + if key.prm.cls.is_some_and(|c| c != cls) { return None; } let alg = match &key.key { diff --git a/atrium-oauth/oauth-client/src/lib.rs b/atrium-oauth/oauth-client/src/lib.rs index 06071dc7..cacd7b7e 100644 --- a/atrium-oauth/oauth-client/src/lib.rs +++ b/atrium-oauth/oauth-client/src/lib.rs @@ -5,6 +5,7 @@ mod http_client; mod jose; mod keyset; mod oauth_client; +mod oauth_session; mod resolver; mod server_agent; pub mod store; @@ -19,7 +20,116 @@ pub use error::{Error, Result}; pub use http_client::default::DefaultHttpClient; pub use http_client::dpop::DpopClient; pub use oauth_client::{OAuthClient, OAuthClientConfig}; +pub use oauth_session::OAuthSession; pub use resolver::OAuthResolverConfig; pub use types::{ AuthorizeOptionPrompt, AuthorizeOptions, CallbackParams, OAuthClientMetadata, TokenSet, }; + +#[cfg(test)] +mod tests { + use crate::{ + resolver::OAuthResolver, + types::{ + OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthProtectedResourceMetadata, + TryIntoOAuthClientMetadata, + }, + AtprotoLocalhostClientMetadata, OAuthResolverConfig, + }; + use atrium_api::{ + did_doc::{DidDocument, Service}, + types::string::{Did, Handle}, + }; + use atrium_common::resolver::Resolver; + use atrium_xrpc::HttpClient; + use jose_jwk::Key; + use std::sync::Arc; + + pub struct MockDidResolver; + + impl Resolver for MockDidResolver { + type Input = Did; + type Output = DidDocument; + type Error = atrium_identity::Error; + async fn resolve(&self, did: &Self::Input) -> Result<Self::Output, Self::Error> { + Ok(DidDocument { + context: None, + id: did.as_ref().to_string(), + also_known_as: None, + verification_method: None, + service: Some(vec![Service { + id: String::from("#atproto_pds"), + r#type: String::from("AtprotoPersonalDataServer"), + service_endpoint: String::from("https://aud.example.com"), + }]), + }) + } + } + + pub struct NoopHandleResolver; + + impl Resolver for NoopHandleResolver { + type Input = Handle; + type Output = Did; + type Error = atrium_identity::Error; + async fn resolve(&self, _: &Self::Input) -> Result<Self::Output, Self::Error> { + unimplemented!() + } + } + + pub fn oauth_resolver<T>( + http_client: Arc<T>, + ) -> OAuthResolver<T, MockDidResolver, NoopHandleResolver> + where + T: HttpClient + Send + Sync, + { + OAuthResolver::new( + OAuthResolverConfig { + did_resolver: MockDidResolver, + handle_resolver: NoopHandleResolver, + authorization_server_metadata: Default::default(), + protected_resource_metadata: Default::default(), + }, + http_client, + ) + } + + pub fn dpop_key() -> Key { + serde_json::from_str( + r#"{ + "kty": "EC", + "crv": "P-256", + "x": "NIRNgPVAwnVNzN5g2Ik2IMghWcjnBOGo9B-lKXSSXFs", + "y": "iWF-Of43XoSTZxcadO9KWdPTjiCoviSztYw7aMtZZMc", + "d": "9MuCYfKK4hf95p_VRj6cxKJwORTgvEU3vynfmSgFH2M" + }"#, + ) + .expect("key should be valid") + } + + pub fn server_metadata() -> OAuthAuthorizationServerMetadata { + OAuthAuthorizationServerMetadata { + issuer: String::from("https://iss.example.com"), + token_endpoint: String::from("https://iss.example.com/token"), + token_endpoint_auth_methods_supported: Some(vec![ + String::from("none"), + String::from("private_key_jwt"), + ]), + ..Default::default() + } + } + + pub fn client_metadata() -> OAuthClientMetadata { + AtprotoLocalhostClientMetadata::default() + .try_into_client_metadata(&None) + .expect("client metadata should be valid") + } + + pub fn protected_resource_metadata() -> OAuthProtectedResourceMetadata { + OAuthProtectedResourceMetadata { + resource: String::from("https://aud.example.com"), + authorization_servers: Some(vec![String::from("https://iss.example.com")]), + ..Default::default() + } + } +} diff --git a/atrium-oauth/oauth-client/src/oauth_client.rs b/atrium-oauth/oauth-client/src/oauth_client.rs index e844f00a..a5721239 100644 --- a/atrium-oauth/oauth-client/src/oauth_client.rs +++ b/atrium-oauth/oauth-client/src/oauth_client.rs @@ -1,29 +1,37 @@ -use crate::constants::FALLBACK_ALG; -use crate::error::{Error, Result}; -use crate::keyset::Keyset; -use crate::resolver::{OAuthResolver, OAuthResolverConfig}; -use crate::server_agent::{OAuthRequest, OAuthServerAgent}; -use crate::store::state::{InternalStateData, StateStore}; -use crate::types::{ - AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions, CallbackParams, - OAuthAuthorizationServerMetadata, OAuthClientMetadata, - OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters, TokenSet, - TryIntoOAuthClientMetadata, +use crate::{ + constants::FALLBACK_ALG, + error::{Error, Result}, + keyset::Keyset, + oauth_session::OAuthSession, + resolver::{OAuthResolver, OAuthResolverConfig}, + server_agent::{OAuthRequest, OAuthServerAgent, OAuthServerFactory}, + store::{ + session::{Session, SessionStore}, + session_registry::SessionRegistry, + state::{InternalStateData, StateStore}, + }, + types::{ + AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions, + CallbackParams, OAuthAuthorizationServerMetadata, OAuthClientMetadata, + OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters, + TryIntoOAuthClientMetadata, + }, + utils::{compare_algos, generate_key, generate_nonce}, +}; +use atrium_api::{ + did_doc::DidDocument, + types::string::{Did, Handle}, }; -use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values}; use atrium_common::resolver::Resolver; -use atrium_identity::{did::DidResolver, handle::HandleResolver}; use atrium_xrpc::HttpClient; -use base64::engine::general_purpose::URL_SAFE_NO_PAD; -use base64::Engine; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use jose_jwk::{Jwk, JwkSet, Key}; -use rand::rngs::ThreadRng; use serde::Serialize; use sha2::{Digest, Sha256}; use std::sync::Arc; #[cfg(feature = "default-client")] -pub struct OAuthClientConfig<S, M, D, H> +pub struct OAuthClientConfig<S0, S1, M, D, H> where M: TryIntoOAuthClientMetadata, { @@ -31,13 +39,14 @@ where pub client_metadata: M, pub keys: Option<Vec<Jwk>>, // Stores - pub state_store: S, + pub state_store: S0, + pub session_store: S1, // Services pub resolver: OAuthResolverConfig<D, H>, } #[cfg(not(feature = "default-client"))] -pub struct OAuthClientConfig<S, T, M, D, H> +pub struct OAuthClientConfig<S0, S1, T, M, D, H> where M: TryIntoOAuthClientMetadata, { @@ -45,7 +54,8 @@ where pub client_metadata: M, pub keys: Option<Vec<Jwk>>, // Stores - pub state_store: S, + pub state_store: S0, + pub session_store: S1, // Services pub resolver: OAuthResolverConfig<D, H>, // Others @@ -53,86 +63,122 @@ where } #[cfg(feature = "default-client")] -pub struct OAuthClient<S, D, H, T = crate::http_client::default::DefaultHttpClient> +pub struct OAuthClient<S0, S1, D, H, T = crate::http_client::default::DefaultHttpClient> where - S: StateStore, T: HttpClient + Send + Sync + 'static, + S1: SessionStore + Send + Sync + 'static, + S1::Error: std::error::Error + Send + Sync + 'static, { pub client_metadata: OAuthClientMetadata, keyset: Option<Keyset>, resolver: Arc<OAuthResolver<T, D, H>>, - state_store: S, + server_factory: Arc<OAuthServerFactory<T, D, H>>, + state_store: S0, + session_registry: Arc<SessionRegistry<S1, T, D, H>>, http_client: Arc<T>, } #[cfg(not(feature = "default-client"))] -pub struct OAuthClient<S, D, H, T> +pub struct OAuthClient<S0, S1, D, H, T> where - S: StateStore, T: HttpClient + Send + Sync + 'static, + S1: SessionStore + Send + Sync + 'static, + S1::Error: std::error::Error + Send + Sync + 'static, { pub client_metadata: OAuthClientMetadata, keyset: Option<Keyset>, resolver: Arc<OAuthResolver<T, D, H>>, - state_store: S, + server_factory: Arc<OAuthServerFactory<T, D, H>>, + state_store: S0, + session_registry: Arc<SessionRegistry<S1, T, D, H>>, http_client: Arc<T>, } #[cfg(feature = "default-client")] -impl<S, D, H> OAuthClient<S, D, H, crate::http_client::default::DefaultHttpClient> +impl<S0, S1, D, H> OAuthClient<S0, S1, D, H, crate::http_client::default::DefaultHttpClient> where - S: StateStore, + S1: SessionStore + Send + Sync + 'static, + S1::Error: std::error::Error + Send + Sync + 'static, { - pub fn new<M>(config: OAuthClientConfig<S, M, D, H>) -> Result<Self> + pub fn new<M>(config: OAuthClientConfig<S0, S1, M, D, H>) -> Result<Self> where M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>, { let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None }; let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?; let http_client = Arc::new(crate::http_client::default::DefaultHttpClient::default()); + let resolver = Arc::new(OAuthResolver::new(config.resolver, Arc::clone(&http_client))); + let server_factory = Arc::new(OAuthServerFactory::new( + client_metadata.clone(), + Arc::clone(&resolver), + Arc::clone(&http_client), + keyset.clone(), + )); + let session_registry = + Arc::new(SessionRegistry::new(config.session_store, Arc::clone(&server_factory))); Ok(Self { client_metadata, keyset, - resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())), + resolver, + server_factory, state_store: config.state_store, + session_registry, http_client, }) } } #[cfg(not(feature = "default-client"))] -impl<S, D, H, T> OAuthClient<S, D, H, T> +impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T> where - S: StateStore, T: HttpClient + Send + Sync + 'static, + S1: SessionStore + Send + Sync + 'static, + S1::Error: std::error::Error + Send + Sync + 'static, { - pub fn new<M>(config: OAuthClientConfig<S, T, M, D, H>) -> Result<Self> + pub fn new<M>(config: OAuthClientConfig<S0, S1, T, M, D, H>) -> Result<Self> where M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>, { let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None }; let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?; let http_client = Arc::new(config.http_client); + let resolver = Arc::new(OAuthResolver::new(config.resolver, Arc::clone(&http_client))); + let server_factory = Arc::new(OAuthServerFactory::new( + client_metadata.clone(), + Arc::clone(&resolver), + Arc::clone(&http_client), + keyset.clone(), + )); + let session_registry = + Arc::new(SessionRegistry::new(config.session_store, Arc::clone(&server_factory))); Ok(Self { client_metadata, keyset, - resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())), + resolver, + server_factory, state_store: config.state_store, + session_registry, http_client, }) } } -impl<S, D, H, T> OAuthClient<S, D, H, T> +impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T> where - S: StateStore, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + S0: StateStore + Send + Sync + 'static, + S1: SessionStore + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync, T: HttpClient + Send + Sync + 'static, + S0::Error: std::error::Error + Send + Sync + 'static, + S1::Error: std::error::Error + Send + Sync + 'static, { pub fn jwks(&self) -> JwkSet { self.keyset.as_ref().map(|keyset| keyset.public_jwks()).unwrap_or_default() } + /// Start the authorization process. + /// + /// This method will return a URL that the user should be redirected to. pub async fn authorize( &self, input: impl AsRef<str>, @@ -156,6 +202,7 @@ where iss: metadata.issuer.clone(), dpop_key: dpop_key.clone(), verifier, + app_state: options.state, }; self.state_store .set(state.clone(), state_data) @@ -174,14 +221,7 @@ where prompt: options.prompt.map(String::from), }; if metadata.pushed_authorization_request_endpoint.is_some() { - let server = OAuthServerAgent::new( - dpop_key, - metadata.clone(), - self.client_metadata.clone(), - self.resolver.clone(), - self.http_client.clone(), - self.keyset.clone(), - )?; + let server = self.server_factory.build_from_metadata(dpop_key, metadata.clone())?; let par_response = server .request::<OAuthPusehedAuthorizationRequestResponse>( OAuthRequest::PushedAuthorizationRequest(parameters), @@ -208,7 +248,14 @@ where todo!() } } - pub async fn callback(&self, params: CallbackParams) -> Result<TokenSet> { + /// Handle the callback from the authorization server. + /// + /// This method will exchange the authorization code for an access token and store the session, + /// and return the [`OAuthSession`] and the application state. + pub async fn callback( + &self, + params: CallbackParams, + ) -> Result<(OAuthSession<T, D, H, S1>, Option<String>)> { let Some(state_key) = params.state else { return Err(Error::Callback("missing `state` parameter".into())); }; @@ -233,18 +280,55 @@ where } else if metadata.authorization_response_iss_parameter_supported == Some(true) { return Err(Error::Callback("missing `iss` parameter".into())); } - let server = OAuthServerAgent::new( - state.dpop_key.clone(), - metadata.clone(), - self.client_metadata.clone(), - self.resolver.clone(), - self.http_client.clone(), - self.keyset.clone(), - )?; - let token_set = server.exchange_code(¶ms.code, &state.verifier).await?; - - // TODO: create session? - Ok(token_set) + let server = + self.server_factory.build_from_metadata(state.dpop_key.clone(), metadata.clone())?; + match server.exchange_code(¶ms.code, &state.verifier).await { + Ok(token_set) => { + let sub = token_set.sub.clone(); + self.session_registry + .set(sub.clone(), Session { dpop_key: state.dpop_key.clone(), token_set }) + .await + .map_err(|e| Error::SessionStore(Box::new(e)))?; + Ok((self.create_session(server, &sub).await?, state.app_state)) + } + Err(_) => { + todo!() + } + } + } + /// Load a stored session by giving the subject DID. + /// + /// This method will return the [`OAuthSession`] if it exists. + pub async fn restore(&self, sub: &Did) -> Result<OAuthSession<T, D, H, S1>> { + // let session_handle = self.session_registry.get(sub).await?; + // let session = session_handle.read().await.session(); + let session = self.session_registry.get(sub, false).await?; + self.create_session( + self.server_factory.build_from_issuer(session.dpop_key, &session.token_set.iss).await?, + sub, + ) + .await + } + /// Revoke a session by giving the subject DID. + pub async fn revoke(&self, sub: &Did) -> Result<()> { + let session = self.session_registry.get(sub, false).await?; + let server_agent = + self.server_factory.build_from_issuer(session.dpop_key, &session.token_set.iss).await?; + server_agent.revoke(&session.token_set.access_token).await?; + self.session_registry.del(sub).await.map_err(|e| Error::SessionStore(Box::new(e))) + } + async fn create_session( + &self, + server: OAuthServerAgent<T, D, H>, + sub: &Did, + ) -> Result<OAuthSession<T, D, H, S1>> { + Ok(OAuthSession::new( + server.server_metadata.clone(), + sub.clone(), + Arc::clone(&self.http_client), + Arc::clone(&self.session_registry), + ) + .await?) } fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option<Key> { let mut algs = @@ -254,8 +338,7 @@ where } fn generate_pkce() -> (String, String) { // https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 - let verifier = - URL_SAFE_NO_PAD.encode(get_random_values::<_, 32>(&mut ThreadRng::default())); + let verifier = [generate_nonce(), generate_nonce()].join(""); (URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier) } } diff --git a/atrium-oauth/oauth-client/src/oauth_session.rs b/atrium-oauth/oauth-client/src/oauth_session.rs new file mode 100644 index 00000000..60cb80b4 --- /dev/null +++ b/atrium-oauth/oauth-client/src/oauth_session.rs @@ -0,0 +1,664 @@ +mod inner; +mod store; + +use self::store::MemorySessionStore; +use crate::{ + http_client::dpop::DpopClient, + store::{session::SessionStore, session_registry::SessionRegistry}, + types::OAuthAuthorizationServerMetadata, +}; +use atrium_api::{ + agent::{utils::SessionWithEndpointStore, CloneWithProxy, Configure, SessionManager}, + did_doc::DidDocument, + types::string::{Did, Handle}, +}; +use atrium_common::resolver::Resolver; +use atrium_xrpc::{ + http::{Request, Response}, + HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, +}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{fmt::Debug, sync::Arc}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + Dpop(#[from] crate::http_client::dpop::Error), + #[error(transparent)] + SessionRegistry(#[from] crate::store::session_registry::Error), + #[error(transparent)] + Store(#[from] atrium_common::store::memory::Error), +} + +pub struct OAuthSession<T, D, H, S> +where + T: HttpClient + Send + Sync + 'static, + S: SessionStore + Send + Sync + 'static, +{ + store: Arc<SessionWithEndpointStore<store::MemorySessionStore, String>>, + inner: inner::Client<S, T, D, H>, + sub: Did, + session_registry: Arc<SessionRegistry<S, T, D, H>>, +} + +impl<T, D, H, S> OAuthSession<T, D, H, S> +where + T: HttpClient + Send + Sync, + D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync, + S: SessionStore + Send + Sync + 'static, +{ + pub(crate) async fn new( + server_metadata: OAuthAuthorizationServerMetadata, + sub: Did, + http_client: Arc<T>, + session_registry: Arc<SessionRegistry<S, T, D, H>>, + ) -> Result<Self, Error> { + // initialize SessionWithEndpointStore + let (dpop_key, token_set) = { + let s = session_registry.get(&sub, false).await?; + (s.dpop_key.clone(), s.token_set.clone()) + }; + let store = Arc::new(SessionWithEndpointStore::new( + MemorySessionStore::default(), + token_set.aud.clone(), + )); + store.set(token_set.access_token.clone()).await?; + // initialize inner client + let inner = inner::Client::new( + Arc::clone(&store), + DpopClient::new( + dpop_key, + http_client, + false, + &server_metadata.token_endpoint_auth_signing_alg_values_supported, + )?, + sub.clone(), + Arc::clone(&session_registry), + ); + Ok(Self { store, inner, sub, session_registry }) + } +} + +impl<T, D, H, S> HttpClient for OAuthSession<T, D, H, S> +where + T: HttpClient + Send + Sync + 'static, + D: Send + Sync, + H: Send + Sync, + S: SessionStore + Send + Sync, +{ + async fn send_http( + &self, + request: Request<Vec<u8>>, + ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> { + self.inner.send_http(request).await + } +} + +impl<T, D, H, S> XrpcClient for OAuthSession<T, D, H, S> +where + T: HttpClient + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync, + S: SessionStore + Send + Sync + 'static, +{ + fn base_uri(&self) -> String { + self.inner.base_uri() + } + async fn send_xrpc<P, I, O, E>( + &self, + request: &XrpcRequest<P, I>, + ) -> Result<OutputDataOrBytes<O>, atrium_xrpc::Error<E>> + where + P: Serialize + Send + Sync, + I: Serialize + Send + Sync, + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + self.inner.send_xrpc(request).await + } +} + +impl<T, D, H, S> SessionManager for OAuthSession<T, D, H, S> +where + T: HttpClient + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync, + S: SessionStore + Send + Sync + 'static, +{ + async fn did(&self) -> Option<Did> { + Some(self.sub.clone()) + } +} + +impl<T, D, H, S> Configure for OAuthSession<T, D, H, S> +where + T: HttpClient + Send + Sync, + S: SessionStore + Send + Sync, +{ + fn configure_endpoint(&self, endpoint: String) { + self.inner.configure_endpoint(endpoint); + } + fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) { + self.inner.configure_labelers_header(labeler_dids); + } + fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) { + self.inner.configure_proxy_header(did, service_type); + } +} + +impl<T, D, H, S> CloneWithProxy for OAuthSession<T, D, H, S> +where + T: HttpClient + Send + Sync, + S: SessionStore + Send + Sync, +{ + fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self { + Self { + store: self.store.clone(), + inner: self.inner.clone_with_proxy(did, service_type), + sub: self.sub.clone(), + session_registry: Arc::clone(&self.session_registry), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::server_agent::OAuthServerFactory; + use crate::tests::{ + client_metadata, dpop_key, oauth_resolver, protected_resource_metadata, server_metadata, + MockDidResolver, NoopHandleResolver, + }; + use crate::{ + jose::jwt::Claims, + store::session::Session, + types::{OAuthTokenResponse, OAuthTokenType, RefreshRequestParameters, TokenSet}, + }; + use atrium_api::{ + agent::{Agent, AtprotoServiceType}, + client::Service, + xrpc::http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, StatusCode}, + }; + use atrium_common::store::Store; + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; + use std::{collections::HashMap, time::Duration}; + use tokio::sync::Mutex; + + #[derive(Default)] + struct RecordData { + host: Option<String>, + headers: HeaderMap<HeaderValue>, + } + + struct MockHttpClient { + data: Arc<Mutex<Option<RecordData>>>, + next_token: Arc<Mutex<Option<OAuthTokenResponse>>>, + } + + impl MockHttpClient { + fn new(data: Arc<Mutex<Option<RecordData>>>) -> Self { + Self { + data, + next_token: Arc::new(Mutex::new(Some(OAuthTokenResponse { + access_token: String::from("new_accesstoken"), + token_type: OAuthTokenType::DPoP, + expires_in: Some(10), + refresh_token: Some(String::from("new_refreshtoken")), + scope: None, + sub: None, + }))), + } + } + } + + impl HttpClient for MockHttpClient { + async fn send_http( + &self, + request: Request<Vec<u8>>, + ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> { + // tick tokio time + tokio::time::sleep(std::time::Duration::from_micros(0)).await; + + match (request.uri().host(), request.uri().path()) { + (Some("iss.example.com"), "/.well-known/oauth-authorization-server") => { + return Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(serde_json::to_vec(&server_metadata())?) + .map_err(|e| e.into()); + } + (Some("aud.example.com"), "/.well-known/oauth-protected-resource") => { + return Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(serde_json::to_vec(&protected_resource_metadata())?) + .map_err(|e| e.into()); + } + _ => {} + } + + let mut headers = request.headers().clone(); + let Some(authorization) = headers + .remove("authorization") + .and_then(|value| value.to_str().map(String::from).ok()) + else { + let response = if request.uri().path() == "/token" { + let parameters = + serde_html_form::from_bytes::<RefreshRequestParameters>(request.body())?; + let token_response = if parameters.refresh_token == "refreshtoken" { + self.next_token.lock().await.take() + } else { + None + }; + if let Some(token_response) = token_response { + Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(serde_json::to_vec(&token_response)?)? + } else { + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("WWW-Authenticate", "DPoP error=\"invalid_token\"") + .body(Vec::new())? + } + } else { + Response::builder().status(StatusCode::UNAUTHORIZED).body(Vec::new())? + }; + return Ok(response); + }; + let Some(token) = authorization.strip_prefix("DPoP ") else { + panic!("authorization header should start with DPoP"); + }; + if token == "expired" { + return Ok(Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("WWW-Authenticate", "DPoP error=\"invalid_token\"") + .body(Vec::new())?); + } + let dpop_jwt = headers.remove("dpop").expect("dpop header should be present"); + let payload = dpop_jwt + .to_str() + .expect("dpop header should be valid") + .split('.') + .nth(1) + .expect("dpop header should have 2 parts"); + let claims = URL_SAFE_NO_PAD + .decode(payload) + .ok() + .and_then(|value| serde_json::from_slice::<Claims>(&value).ok()) + .expect("dpop payload should be valid"); + assert!(claims.registered.iat.is_some()); + assert!(claims.registered.jti.is_some()); + assert_eq!(claims.public.htm, Some(request.method().to_string())); + assert_eq!(claims.public.htu, Some(request.uri().to_string())); + + self.data + .lock() + .await + .replace(RecordData { host: request.uri().host().map(String::from), headers }); + let output = atrium_api::com::atproto::server::get_service_auth::OutputData { + token: String::from("fake_token"), + }; + Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(serde_json::to_vec(&output)?) + .map_err(|e| e.into()) + } + } + + struct MockSessionStore { + data: Arc<Mutex<HashMap<Did, Session>>>, + } + + impl Store<Did, Session> for MockSessionStore { + type Error = Error; + + async fn get(&self, key: &Did) -> Result<Option<Session>, Self::Error> { + tokio::time::sleep(Duration::from_micros(10)).await; + Ok(self.data.lock().await.get(key).cloned()) + } + async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> { + tokio::time::sleep(Duration::from_micros(10)).await; + self.data.lock().await.insert(key, value); + Ok(()) + } + async fn del(&self, _: &Did) -> Result<(), Self::Error> { + unimplemented!() + } + async fn clear(&self) -> Result<(), Self::Error> { + unimplemented!() + } + } + + impl SessionStore for MockSessionStore {} + + fn did() -> Did { + Did::new(String::from("did:fake:sub.test")).expect("did should be valid") + } + + fn default_store() -> Arc<Mutex<HashMap<Did, Session>>> { + let did = did(); + let token_set = TokenSet { + iss: String::from("https://iss.example.com"), + sub: did.clone(), + aud: String::from("https://aud.example.com"), + scope: None, + refresh_token: Some(String::from("refreshtoken")), + access_token: String::from("accesstoken"), + token_type: OAuthTokenType::DPoP, + expires_at: None, + }; + let dpop_key = dpop_key(); + let session = Session { token_set, dpop_key }; + Arc::new(Mutex::new(HashMap::from_iter([(did, session)]))) + } + + async fn oauth_session( + data: Arc<Mutex<Option<RecordData>>>, + store: Arc<Mutex<HashMap<Did, Session>>>, + ) -> OAuthSession<MockHttpClient, MockDidResolver, NoopHandleResolver, MockSessionStore> { + let http_client = Arc::new(MockHttpClient::new(data)); + let resolver = Arc::new(oauth_resolver(Arc::clone(&http_client))); + let server_factory = Arc::new(OAuthServerFactory::new( + client_metadata(), + resolver, + Arc::clone(&http_client), + None, + )); + let session_registory = Arc::new(SessionRegistry::new( + MockSessionStore { data: Arc::clone(&store) }, + server_factory, + )); + OAuthSession::new(server_metadata(), did(), http_client, session_registory) + .await + .expect("failed to create oauth session") + } + + async fn oauth_agent( + data: Arc<Mutex<Option<RecordData>>>, + ) -> Agent<impl SessionManager + Configure + CloneWithProxy> { + Agent::new(oauth_session(data, default_store()).await) + } + + async fn call_service( + service: &Service<impl SessionManager + Sync>, + ) -> Result<(), atrium_xrpc::Error<atrium_api::com::atproto::server::get_service_auth::Error>> + { + let output = service + .com + .atproto + .server + .get_service_auth( + atrium_api::com::atproto::server::get_service_auth::ParametersData { + aud: Did::new(String::from("did:fake:handle.test")) + .expect("did should be valid"), + exp: None, + lxm: None, + } + .into(), + ) + .await?; + assert_eq!(output.token, "fake_token"); + Ok(()) + } + + #[tokio::test] + async fn test_new() -> Result<(), Box<dyn std::error::Error>> { + let agent = oauth_agent(Default::default()).await; + assert_eq!(agent.did().await.as_deref(), Some("did:fake:sub.test")); + Ok(()) + } + + #[tokio::test] + async fn test_configure_endpoint() -> Result<(), Box<dyn std::error::Error>> { + let data = Default::default(); + let agent = oauth_agent(Arc::clone(&data)).await; + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").host.as_deref(), + Some("aud.example.com") + ); + agent.configure_endpoint(String::from("https://pds.example.com")); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").host.as_deref(), + Some("pds.example.com") + ); + Ok(()) + } + + #[tokio::test] + async fn test_configure_labelers_header() -> Result<(), Box<dyn std::error::Error>> { + let data = Default::default(); + let agent = oauth_agent(Arc::clone(&data)).await; + // not configured + { + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::new() + ); + } + // configured 1 + { + agent.configure_labelers_header(Some(vec![( + Did::new(String::from("did:fake:labeler.test"))?, + false, + )])); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-accept-labelers"), + HeaderValue::from_static("did:fake:labeler.test"), + )]) + ); + } + // configured 2 + { + agent.configure_labelers_header(Some(vec![ + (Did::new(String::from("did:fake:labeler.test_redact"))?, true), + (Did::new(String::from("did:fake:labeler.test"))?, false), + ])); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-accept-labelers"), + HeaderValue::from_static( + "did:fake:labeler.test_redact;redact, did:fake:labeler.test" + ), + )]) + ); + } + Ok(()) + } + + #[tokio::test] + async fn test_configure_proxy_header() -> Result<(), Box<dyn std::error::Error>> { + let data = Arc::new(Mutex::new(Default::default())); + let agent = oauth_agent(Arc::clone(&data)).await; + // not configured + { + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::new() + ); + } + // labeler service + { + agent.configure_proxy_header( + Did::new(String::from("did:fake:service.test"))?, + AtprotoServiceType::AtprotoLabeler, + ); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:fake:service.test#atproto_labeler"), + )]) + ); + } + // custom service + { + agent.configure_proxy_header( + Did::new(String::from("did:fake:service.test"))?, + "custom_service", + ); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:fake:service.test#custom_service"), + )]) + ); + } + // api_with_proxy + { + call_service( + &agent.api_with_proxy( + Did::new(String::from("did:fake:service.test"))?, + "temp_service", + ), + ) + .await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:fake:service.test#temp_service"), + )]) + ); + call_service(&agent.api).await?; + assert_eq!( + data.lock().await.as_ref().expect("data should be recorded").headers, + HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:fake:service.test#custom_service"), + )]) + ); + } + Ok(()) + } + + #[tokio::test] + async fn test_xrpc_without_token() -> Result<(), Box<dyn std::error::Error>> { + let oauth_session = oauth_session(Default::default(), default_store()).await; + oauth_session.store.clear().await?; + let agent = Agent::new(oauth_session); + let result = agent + .api + .com + .atproto + .server + .get_service_auth( + atrium_api::com::atproto::server::get_service_auth::ParametersData { + aud: Did::new(String::from("did:fake:handle.test")) + .expect("did should be valid"), + exp: None, + lxm: None, + } + .into(), + ) + .await; + match result.expect_err("should fail without token") { + atrium_xrpc::Error::XrpcResponse(err) => { + assert_eq!(err.status, StatusCode::UNAUTHORIZED); + } + _ => panic!("unexpected error"), + } + Ok(()) + } + + #[tokio::test] + async fn test_xrpc_with_refresh() -> Result<(), Box<dyn std::error::Error>> { + let session_data = default_store(); + if let Some(session) = session_data.lock().await.get_mut(&did()) { + session.token_set.access_token = String::from("expired"); + } + let oauth_session = oauth_session(Default::default(), Arc::clone(&session_data)).await; + let agent = Agent::new(oauth_session); + let result = agent + .api + .com + .atproto + .server + .get_service_auth( + atrium_api::com::atproto::server::get_service_auth::ParametersData { + aud: Did::new(String::from("did:fake:handle.test")) + .expect("did should be valid"), + exp: None, + lxm: None, + } + .into(), + ) + .await; + match result { + Ok(output) => { + assert_eq!(output.token, "fake_token"); + } + Err(err) => { + panic!("unexpected error: {err:?}"); + } + } + // wait for async update + tokio::time::sleep(Duration::from_micros(0)).await; + { + let token_set = session_data + .lock() + .await + .get(&did()) + .expect("session should be present") + .token_set + .clone(); + assert_eq!(token_set.access_token, "new_accesstoken"); + assert_eq!(token_set.refresh_token, Some(String::from("new_refreshtoken"))); + } + Ok(()) + } + + #[tokio::test] + async fn test_xrpc_with_duplicated_refresh() -> Result<(), Box<dyn std::error::Error>> { + let session_data = default_store(); + if let Some(session) = session_data.lock().await.get_mut(&did()) { + session.token_set.access_token = String::from("expired"); + } + let oauth_session = oauth_session(Default::default(), session_data).await; + let agent = Arc::new(Agent::new(oauth_session)); + + let handles = (0..3).map(|_| { + let agent = Arc::clone(&agent); + tokio::spawn(async move { + agent + .api + .com + .atproto + .server + .get_service_auth( + atrium_api::com::atproto::server::get_service_auth::ParametersData { + aud: Did::new(String::from("did:fake:handle.test")) + .expect("did should be valid"), + exp: None, + lxm: None, + } + .into(), + ) + .await + }) + }); + for result in futures::future::join_all(handles).await { + match result? { + Ok(output) => { + assert_eq!(output.token, "fake_token"); + } + Err(err) => { + panic!("unexpected error: {err:?}"); + } + } + } + Ok(()) + } +} diff --git a/atrium-oauth/oauth-client/src/oauth_session/inner.rs b/atrium-oauth/oauth-client/src/oauth_session/inner.rs new file mode 100644 index 00000000..98fa27f5 --- /dev/null +++ b/atrium-oauth/oauth-client/src/oauth_session/inner.rs @@ -0,0 +1,154 @@ +use super::store::MemorySessionStore; +use crate::{ + store::{session::SessionStore, session_registry::SessionRegistry}, + DpopClient, +}; +use atrium_api::{ + agent::{ + utils::{SessionClient, SessionWithEndpointStore}, + CloneWithProxy, Configure, + }, + did_doc::DidDocument, + types::string::{Did, Handle}, +}; +use atrium_common::resolver::Resolver; +use atrium_xrpc::{ + http::{Request, Response}, + Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, +}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{fmt::Debug, sync::Arc}; + +pub struct Client<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, +{ + inner: SessionClient<MemorySessionStore, DpopClient<T>, String>, + store: Arc<SessionWithEndpointStore<MemorySessionStore, String>>, + sub: Did, + session_registry: Arc<SessionRegistry<S, T, D, H>>, +} + +impl<S, T, D, H> Client<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, +{ + pub fn new( + store: Arc<SessionWithEndpointStore<MemorySessionStore, String>>, + xrpc: DpopClient<T>, + sub: Did, + session_registry: Arc<SessionRegistry<S, T, D, H>>, + ) -> Self { + let inner = SessionClient::new(Arc::clone(&store), xrpc); + Self { inner, store, sub, session_registry } + } +} + +impl<S, T, D, H> Client<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync, +{ + // https://datatracker.ietf.org/doc/html/rfc6750#section-3 + fn is_invalid_token_response<O, E>(result: &Result<OutputDataOrBytes<O>, Error<E>>) -> bool + where + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + match result { + Err(Error::Authentication(value)) => value + .to_str() + .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), + _ => false, + } + } + async fn refresh_token(&self) { + if let Ok(session) = self.session_registry.get(&self.sub, true).await { + let _ = self.store.set(session.token_set.access_token.clone()).await; + } + } +} + +impl<S, T, D, H> HttpClient for Client<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, + D: Send + Sync, + H: Send + Sync, +{ + async fn send_http( + &self, + request: Request<Vec<u8>>, + ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> { + self.inner.send_http(request).await + } +} + +impl<S, T, D, H> XrpcClient for Client<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync, +{ + fn base_uri(&self) -> String { + self.inner.base_uri() + } + async fn send_xrpc<P, I, O, E>( + &self, + request: &XrpcRequest<P, I>, + ) -> Result<OutputDataOrBytes<O>, Error<E>> + where + P: Serialize + Send + Sync, + I: Serialize + Send + Sync, + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + let result = self.inner.send_xrpc(request).await; + // handle session-refreshes as needed + if Self::is_invalid_token_response(&result) { + self.refresh_token().await; + self.inner.send_xrpc(request).await + } else { + result + } + } +} + +impl<S, T, D, H> Configure for Client<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, +{ + fn configure_endpoint(&self, endpoint: String) { + self.inner.configure_endpoint(endpoint) + } + /// Configures the moderation services to be applied on requests. + fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) { + self.inner.configure_labelers_header(labeler_dids) + } + /// Configures the atproto-proxy header to be applied on requests. + fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) { + self.inner.configure_proxy_header(did, service_type) + } +} + +impl<S, T, D, H> CloneWithProxy for Client<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, + SessionClient<S, T, String>: CloneWithProxy, +{ + fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self { + Self { + inner: self.inner.clone_with_proxy(did, service_type), + store: Arc::clone(&self.store), + sub: self.sub.clone(), + session_registry: Arc::clone(&self.session_registry), + } + } +} diff --git a/atrium-oauth/oauth-client/src/oauth_session/store.rs b/atrium-oauth/oauth-client/src/oauth_session/store.rs new file mode 100644 index 00000000..cfcb822a --- /dev/null +++ b/atrium-oauth/oauth-client/src/oauth_session/store.rs @@ -0,0 +1,29 @@ +use atrium_api::agent::AuthorizationProvider; +use atrium_common::store::{self, memory::MemoryStore, Store}; +use atrium_xrpc::types::AuthorizationToken; + +#[derive(Default)] +pub struct MemorySessionStore(MemoryStore<(), String>); + +impl Store<(), String> for MemorySessionStore { + type Error = store::memory::Error; + + async fn get(&self, key: &()) -> Result<Option<String>, Self::Error> { + self.0.get(key).await + } + async fn set(&self, key: (), value: String) -> Result<(), Self::Error> { + self.0.set(key, value).await + } + async fn del(&self, key: &()) -> Result<(), Self::Error> { + self.0.del(key).await + } + async fn clear(&self) -> Result<(), Self::Error> { + self.0.clear().await + } +} + +impl AuthorizationProvider for MemorySessionStore { + async fn authorization_token(&self, _: bool) -> Option<AuthorizationToken> { + self.0.get(&()).await.ok().flatten().map(AuthorizationToken::Dpop) + } +} diff --git a/atrium-oauth/oauth-client/src/resolver.rs b/atrium-oauth/oauth-client/src/resolver.rs index d75f7abe..81e8f2ad 100644 --- a/atrium-oauth/oauth-client/src/resolver.rs +++ b/atrium-oauth/oauth-client/src/resolver.rs @@ -1,26 +1,29 @@ -use atrium_common::resolver::CachedResolver; -use atrium_common::resolver::Resolver; -use atrium_common::resolver::ThrottledResolver; -use atrium_common::types::cached::r#impl::Cache; -use atrium_common::types::cached::r#impl::CacheImpl; -use atrium_common::types::cached::CacheConfig; -use atrium_common::types::cached::Cacheable; -use atrium_common::types::throttled::Throttleable; mod oauth_authorization_server_resolver; mod oauth_protected_resource_resolver; use self::oauth_authorization_server_resolver::DefaultOAuthAuthorizationServerResolver; use self::oauth_protected_resource_resolver::DefaultOAuthProtectedResourceResolver; use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; -use atrium_identity::identity_resolver::{ - IdentityResolver, IdentityResolverConfig, ResolvedIdentity, +use atrium_api::{ + did_doc::DidDocument, + types::string::{Did, Handle}, +}; +use atrium_common::{ + resolver::{CachedResolver, Resolver, ThrottledResolver}, + types::{ + cached::{ + r#impl::{Cache, CacheImpl}, + {CacheConfig, Cacheable}, + }, + throttled::Throttleable, + }, +}; +use atrium_identity::{ + identity_resolver::{IdentityResolver, IdentityResolverConfig, ResolvedIdentity}, + {Error, Result}, }; -use atrium_identity::{did::DidResolver, handle::HandleResolver}; -use atrium_identity::{Error, Result}; use atrium_xrpc::HttpClient; -use std::marker::PhantomData; -use std::sync::Arc; -use std::time::Duration; +use std::{marker::PhantomData, sync::Arc, time::Duration}; #[derive(Clone, Debug)] pub struct OAuthAuthorizationServerMetadataResolverConfig { @@ -106,8 +109,8 @@ where impl<T, D, H> OAuthResolver<T, D, H> where T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync, { pub async fn get_authorization_server_metadata( &self, @@ -185,8 +188,8 @@ where impl<T, D, H> Resolver for OAuthResolver<T, D, H> where T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync, { type Input = str; type Output = (OAuthAuthorizationServerMetadata, Option<ResolvedIdentity>); diff --git a/atrium-oauth/oauth-client/src/server_agent.rs b/atrium-oauth/oauth-client/src/server_agent.rs index c9d556f3..695c3178 100644 --- a/atrium-oauth/oauth-client/src/server_agent.rs +++ b/atrium-oauth/oauth-client/src/server_agent.rs @@ -1,18 +1,28 @@ -use crate::constants::FALLBACK_ALG; -use crate::http_client::dpop::DpopClient; -use crate::jose::jwt::{RegisteredClaims, RegisteredClaimsAud}; -use crate::keyset::Keyset; -use crate::resolver::OAuthResolver; -use crate::types::{ - OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse, - PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, - TokenRequestParameters, TokenSet, +mod factory; + +pub use self::factory::OAuthServerFactory; +use crate::{ + constants::FALLBACK_ALG, + http_client::dpop::DpopClient, + jose::jwt::{RegisteredClaims, RegisteredClaimsAud}, + keyset::Keyset, + resolver::OAuthResolver, + types::{ + OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse, + PushedAuthorizationRequestParameters, RefreshRequestParameters, + RevocationRequestParameters, TokenGrantType, TokenRequestParameters, TokenSet, + }, + utils::{compare_algos, generate_nonce}, +}; +use atrium_api::{ + did_doc::DidDocument, + types::string::{Datetime, Did, Handle}, +}; +use atrium_common::resolver::Resolver; +use atrium_xrpc::{ + http::{Method, Request, StatusCode}, + HttpClient, }; -use crate::utils::{compare_algos, generate_nonce}; -use atrium_api::types::string::Datetime; -use atrium_identity::{did::DidResolver, handle::HandleResolver}; -use atrium_xrpc::http::{Method, Request, StatusCode}; -use atrium_xrpc::HttpClient; use chrono::{TimeDelta, Utc}; use jose_jwk::Key; use serde::Serialize; @@ -32,8 +42,14 @@ pub enum Error { Token(String), #[error("unsupported authentication method")] UnsupportedAuthMethod, + #[error("no refresh token available")] + TokenRefresh, + #[error("failed to parse DID: {0}")] + InvalidDid(&'static str), #[error(transparent)] DpopClient(#[from] crate::http_client::dpop::Error), + // #[error(transparent)] + // OAuthSession(#[from] crate::oauth_session::Error), #[error(transparent)] Http(#[from] atrium_xrpc::http::Error), #[error("http client error: {0}")] @@ -58,7 +74,7 @@ pub type Result<T> = core::result::Result<T, Error>; pub enum OAuthRequest { Token(TokenRequestParameters), Refresh(RefreshRequestParameters), - Revocation, + Revocation(RevocationRequestParameters), Introspection, PushedAuthorizationRequest(PushedAuthorizationRequestParameters), } @@ -68,7 +84,7 @@ impl OAuthRequest { String::from(match self { Self::Token(_) => "token", Self::Refresh(_) => "refresh", - Self::Revocation => "revocation", + Self::Revocation(_) => "revocation", Self::Introspection => "introspection", Self::PushedAuthorizationRequest(_) => "pushed_authorization_request", }) @@ -77,6 +93,8 @@ impl OAuthRequest { match self { Self::Token(_) | Self::Refresh(_) => StatusCode::OK, Self::PushedAuthorizationRequest(_) => StatusCode::CREATED, + // Unlike https://datatracker.ietf.org/doc/html/rfc7009#section-2.2, oauth-provider seems to return `204`. + Self::Revocation(_) => StatusCode::NO_CONTENT, _ => unimplemented!(), } } @@ -100,7 +118,7 @@ pub struct OAuthServerAgent<T, D, H> where T: HttpClient + Send + Sync + 'static, { - server_metadata: OAuthAuthorizationServerMetadata, + pub(crate) server_metadata: OAuthAuthorizationServerMetadata, client_metadata: OAuthClientMetadata, dpop_client: DpopClient<T>, resolver: Arc<OAuthResolver<T, D, H>>, @@ -110,8 +128,8 @@ where impl<T, D, H> OAuthServerAgent<T, D, H> where T: HttpClient + Send + Sync + 'static, - D: DidResolver + Send + Sync + 'static, - H: HandleResolver + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync, { pub fn new( dpop_key: Key, @@ -129,23 +147,32 @@ where )?; Ok(Self { server_metadata, client_metadata, dpop_client, resolver, keyset }) } - /** - * VERY IMPORTANT ! Always call this to process token responses. - * - * Whenever an OAuth token response is received, we **MUST** verify that the - * "sub" is a DID, whose issuer authority is indeed the server we just - * obtained credentials from. This check is a critical step to actually be - * able to use the "sub" (DID) as being the actual user's identifier. - */ - async fn verify_token_response(&self, token_response: OAuthTokenResponse) -> Result<TokenSet> { - // ATPROTO requires that the "sub" is always present in the token response. - let Some(sub) = &token_response.sub else { + pub async fn revoke(&self, token: &str) -> Result<()> { + self.request::<()>(OAuthRequest::Revocation(RevocationRequestParameters { + token: token.into(), + })) + .await?; + Ok(()) + } + pub async fn exchange_code(&self, code: &str, verifier: &str) -> Result<TokenSet> { + let token_response = self + .request::<OAuthTokenResponse>(OAuthRequest::Token(TokenRequestParameters { + grant_type: TokenGrantType::AuthorizationCode, + code: code.into(), + redirect_uri: self.client_metadata.redirect_uris[0].clone(), // ? + code_verifier: verifier.into(), + })) + .await?; + let Some(sub) = token_response.sub else { return Err(Error::Token("missing `sub` in token response".into())); }; - let (metadata, identity) = self.resolver.resolve_from_identity(sub).await?; - if metadata.issuer != self.server_metadata.issuer { - return Err(Error::Token("issuer mismatch".into())); - } + let sub = sub.parse().map_err(Error::InvalidDid)?; + // /!\ IMPORTANT /!\ + // + // The token_response MUST always be valid before the "sub" it contains + // can be trusted (see Atproto's OAuth spec for details). + let aud = self.verify_issuer(&sub).await?; + let expires_at = token_response.expires_in.and_then(|expires_in| { Datetime::now() .as_ref() @@ -153,9 +180,9 @@ where .map(Datetime::new) }); Ok(TokenSet { - sub: sub.clone(), - aud: identity.pds, - iss: metadata.issuer, + iss: self.server_metadata.issuer.clone(), + sub, + aud, scope: token_response.scope, access_token: token_response.access_token, refresh_token: token_response.refresh_token, @@ -163,17 +190,62 @@ where expires_at, }) } - pub async fn exchange_code(&self, code: &str, verifier: &str) -> Result<TokenSet> { - self.verify_token_response( - self.request(OAuthRequest::Token(TokenRequestParameters { - grant_type: TokenGrantType::AuthorizationCode, - code: code.into(), - redirect_uri: self.client_metadata.redirect_uris[0].clone(), // ? - code_verifier: verifier.into(), + pub async fn refresh(&self, token_set: &TokenSet) -> Result<TokenSet> { + let Some(refresh_token) = token_set.refresh_token.as_ref() else { + return Err(Error::TokenRefresh); + }; + + // /!\ IMPORTANT /!\ + // + // The "sub" MUST be a DID, whose issuer authority is indeed the server we + // are trying to obtain credentials from. Note that we are doing this + // *before* we actually try to refresh the token: + // 1) To avoid unnecessary refresh + // 2) So that the refresh is the last async operation, ensuring as few + // async operations happen before the result gets a chance to be stored. + let aud = self.verify_issuer(&token_set.sub).await?; + + let response = self + .request::<OAuthTokenResponse>(OAuthRequest::Refresh(RefreshRequestParameters { + grant_type: TokenGrantType::RefreshToken, + refresh_token: refresh_token.clone(), + scope: None, })) - .await?, - ) - .await + .await?; + + let expires_at = response.expires_in.and_then(|expires_in| { + Datetime::now() + .as_ref() + .checked_add_signed(TimeDelta::seconds(expires_in)) + .map(Datetime::new) + }); + Ok(TokenSet { + iss: self.server_metadata.issuer.clone(), + sub: token_set.sub.clone(), + aud, + scope: response.scope, + access_token: response.access_token, + refresh_token: response.refresh_token, + token_type: response.token_type, + expires_at, + }) + } + /** + * VERY IMPORTANT ! Always call this to process token responses. + * + * Whenever an OAuth token response is received, we **MUST** verify that the + * "sub" is a DID, whose issuer authority is indeed the server we just + * obtained credentials from. This check is a critical step to actually be + * able to use the "sub" (DID) as being the actual user's identifier. + * + * @returns The user's PDS URL (the resource server for the user) + */ + async fn verify_issuer(&self, sub: &Did) -> Result<String> { + let (metadata, identity) = self.resolver.resolve_from_identity(sub).await?; + if metadata.issuer != self.server_metadata.issuer { + return Err(Error::Token("issuer mismatch".into())); + } + Ok(identity.pds) } pub async fn request<O>(&self, request: OAuthRequest) -> Result<O> where @@ -185,6 +257,7 @@ where let body = match &request { OAuthRequest::Token(params) => self.build_body(params)?, OAuthRequest::Refresh(params) => self.build_body(params)?, + OAuthRequest::Revocation(params) => self.build_body(params)?, OAuthRequest::PushedAuthorizationRequest(params) => self.build_body(params)?, _ => unimplemented!(), }; @@ -195,7 +268,13 @@ where .body(body.into_bytes())?; let res = self.dpop_client.send_http(req).await.map_err(Error::HttpClient)?; if res.status() == request.expected_status() { - Ok(serde_json::from_slice(res.body())?) + let body = res.body(); + if body.is_empty() { + // since an empty body cannot be deserialized, use “null” temporarily to allow deserialization to `()`. + Ok(serde_json::from_slice(b"null")?) + } else { + Ok(serde_json::from_slice(body)?) + } } else if res.status().is_client_error() { Err(Error::HttpStatusWithBody(res.status(), serde_json::from_slice(res.body())?)) } else { @@ -221,7 +300,7 @@ where Some("private_key_jwt") if method_supported .as_ref() - .map_or(false, |v| v.contains(&String::from("private_key_jwt"))) => + .is_some_and(|v| v.contains(&String::from("private_key_jwt"))) => { if let Some(keyset) = &self.keyset { let mut algs = self @@ -258,9 +337,7 @@ where } } Some("none") - if method_supported - .as_ref() - .map_or(false, |v| v.contains(&String::from("none"))) => + if method_supported.as_ref().is_some_and(|v| v.contains(&String::from("none"))) => { return Ok((None, None)) } @@ -273,7 +350,7 @@ where OAuthRequest::Token(_) | OAuthRequest::Refresh(_) => { Some(&self.server_metadata.token_endpoint) } - OAuthRequest::Revocation => self.server_metadata.revocation_endpoint.as_ref(), + OAuthRequest::Revocation(_) => self.server_metadata.revocation_endpoint.as_ref(), OAuthRequest::Introspection => self.server_metadata.introspection_endpoint.as_ref(), OAuthRequest::PushedAuthorizationRequest(_) => { self.server_metadata.pushed_authorization_request_endpoint.as_ref() @@ -281,3 +358,18 @@ where } } } + +impl<T, D, H> Clone for OAuthServerAgent<T, D, H> +where + T: HttpClient + Send + Sync + 'static, +{ + fn clone(&self) -> Self { + Self { + server_metadata: self.server_metadata.clone(), + client_metadata: self.client_metadata.clone(), + dpop_client: self.dpop_client.clone(), + resolver: Arc::clone(&self.resolver), + keyset: self.keyset.clone(), + } + } +} diff --git a/atrium-oauth/oauth-client/src/server_agent/factory.rs b/atrium-oauth/oauth-client/src/server_agent/factory.rs new file mode 100644 index 00000000..8cc55a45 --- /dev/null +++ b/atrium-oauth/oauth-client/src/server_agent/factory.rs @@ -0,0 +1,69 @@ +use super::{OAuthServerAgent, Result}; +use crate::{ + keyset::Keyset, + resolver::OAuthResolver, + types::{OAuthAuthorizationServerMetadata, OAuthClientMetadata}, +}; +use atrium_api::{ + did_doc::DidDocument, + types::string::{Did, Handle}, +}; +use atrium_common::resolver::Resolver; +use atrium_identity::Error; +use atrium_xrpc::HttpClient; +use jose_jwk::Key; +use std::sync::Arc; + +pub struct OAuthServerFactory<T, D, H> +where + T: HttpClient + Send + Sync + 'static, +{ + client_metadata: OAuthClientMetadata, + resolver: Arc<OAuthResolver<T, D, H>>, + http_client: Arc<T>, + keyset: Option<Keyset>, +} + +impl<T, D, H> OAuthServerFactory<T, D, H> +where + T: HttpClient + Send + Sync + 'static, +{ + pub fn new( + client_metadata: OAuthClientMetadata, + resolver: Arc<OAuthResolver<T, D, H>>, + http_client: Arc<T>, + keyset: Option<Keyset>, + ) -> Self { + OAuthServerFactory { client_metadata, resolver, http_client, keyset } + } +} + +impl<T, D, H> OAuthServerFactory<T, D, H> +where + T: HttpClient + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync, +{ + pub async fn build_from_issuer( + &self, + dpop_key: Key, + issuer: impl AsRef<str>, + ) -> Result<OAuthServerAgent<T, D, H>> { + let server_metadata = self.resolver.get_authorization_server_metadata(&issuer).await?; + self.build_from_metadata(dpop_key, server_metadata) + } + pub fn build_from_metadata( + &self, + dpop_key: Key, + server_metadata: OAuthAuthorizationServerMetadata, + ) -> Result<OAuthServerAgent<T, D, H>> { + OAuthServerAgent::new( + dpop_key, + server_metadata, + self.client_metadata.clone(), + Arc::clone(&self.resolver), + Arc::clone(&self.http_client), + self.keyset.clone(), + ) + } +} diff --git a/atrium-oauth/oauth-client/src/store.rs b/atrium-oauth/oauth-client/src/store.rs index 0850617c..4b89116c 100644 --- a/atrium-oauth/oauth-client/src/store.rs +++ b/atrium-oauth/oauth-client/src/store.rs @@ -1,20 +1,3 @@ -pub mod memory; +pub mod session; +pub mod session_registry; pub mod state; - -use std::error::Error; -use std::future::Future; -use std::hash::Hash; - -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] -pub trait SimpleStore<K, V> -where - K: Eq + Hash, - V: Clone, -{ - type Error: Error + Send + Sync + 'static; - - fn get(&self, key: &K) -> impl Future<Output = Result<Option<V>, Self::Error>>; - fn set(&self, key: K, value: V) -> impl Future<Output = Result<(), Self::Error>>; - fn del(&self, key: &K) -> impl Future<Output = Result<(), Self::Error>>; - fn clear(&self) -> impl Future<Output = Result<(), Self::Error>>; -} diff --git a/atrium-oauth/oauth-client/src/store/memory.rs b/atrium-oauth/oauth-client/src/store/memory.rs deleted file mode 100644 index c43c557d..00000000 --- a/atrium-oauth/oauth-client/src/store/memory.rs +++ /dev/null @@ -1,45 +0,0 @@ -use super::SimpleStore; -use std::collections::HashMap; -use std::fmt::Debug; -use std::hash::Hash; -use std::sync::{Arc, Mutex}; -use thiserror::Error; - -#[derive(Error, Debug)] -#[error("memory store error")] -pub struct Error; - -// TODO: LRU cache? -pub struct MemorySimpleStore<K, V> { - store: Arc<Mutex<HashMap<K, V>>>, -} - -impl<K, V> Default for MemorySimpleStore<K, V> { - fn default() -> Self { - Self { store: Arc::new(Mutex::new(HashMap::new())) } - } -} - -impl<K, V> SimpleStore<K, V> for MemorySimpleStore<K, V> -where - K: Debug + Eq + Hash + Send + Sync + 'static, - V: Debug + Clone + Send + Sync + 'static, -{ - type Error = Error; - - async fn get(&self, key: &K) -> Result<Option<V>, Self::Error> { - Ok(self.store.lock().unwrap().get(key).cloned()) - } - async fn set(&self, key: K, value: V) -> Result<(), Self::Error> { - self.store.lock().unwrap().insert(key, value); - Ok(()) - } - async fn del(&self, key: &K) -> Result<(), Self::Error> { - self.store.lock().unwrap().remove(key); - Ok(()) - } - async fn clear(&self) -> Result<(), Self::Error> { - self.store.lock().unwrap().clear(); - Ok(()) - } -} diff --git a/atrium-oauth/oauth-client/src/store/session.rs b/atrium-oauth/oauth-client/src/store/session.rs new file mode 100644 index 00000000..d6efce63 --- /dev/null +++ b/atrium-oauth/oauth-client/src/store/session.rs @@ -0,0 +1,17 @@ +use crate::types::TokenSet; +use atrium_api::types::string::Did; +use atrium_common::store::{memory::MemoryStore, Store}; +use jose_jwk::Key; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Session { + pub dpop_key: Key, + pub token_set: TokenSet, +} + +pub trait SessionStore: Store<Did, Session> {} + +pub type MemorySessionStore = MemoryStore<Did, Session>; + +impl SessionStore for MemorySessionStore {} diff --git a/atrium-oauth/oauth-client/src/store/session_registry.rs b/atrium-oauth/oauth-client/src/store/session_registry.rs new file mode 100644 index 00000000..8f1f7a2b --- /dev/null +++ b/atrium-oauth/oauth-client/src/store/session_registry.rs @@ -0,0 +1,299 @@ +use crate::{ + server_agent::OAuthServerFactory, + store::session::{Session, SessionStore}, +}; +use atrium_api::{ + did_doc::DidDocument, + types::string::{Datetime, Did, Handle}, +}; +use atrium_common::resolver::Resolver; +use atrium_xrpc::HttpClient; +use dashmap::DashMap; +use std::sync::Arc; +use thiserror::Error; +use tokio::sync::Mutex; + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + ServerAgent(#[from] crate::server_agent::Error), + #[error("session store error: {0}")] + Store(String), + #[error("session does not exist")] + SessionNotFound, +} + +pub struct SessionRegistry<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, +{ + store: Arc<S>, + server_factory: Arc<OAuthServerFactory<T, D, H>>, + pending: DashMap<Did, Arc<Mutex<()>>>, +} + +impl<S, T, D, H> SessionRegistry<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, +{ + pub fn new(store: S, server_factory: Arc<OAuthServerFactory<T, D, H>>) -> Self { + let store = Arc::new(store); + Self { store: Arc::clone(&store), server_factory, pending: DashMap::new() } + } +} + +impl<S, T, D, H> SessionRegistry<S, T, D, H> +where + S: SessionStore + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, + D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync, + H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync, +{ + async fn get_refreshed(&self, key: &Did) -> Result<Session, Error> { + let lock = + self.pending.entry(key.clone()).or_insert_with(|| Arc::new(Mutex::new(()))).clone(); + let _guard = lock.lock().await; + + let mut session = self + .store + .get(key) + .await + .map_err(|e| Error::Store(e.to_string()))? + .ok_or(Error::SessionNotFound)?; + if let Some(expires_at) = &session.token_set.expires_at { + if expires_at > &Datetime::now() { + return Ok(session); + } + } + + let server = self + .server_factory + .build_from_issuer(session.dpop_key.clone(), &session.token_set.iss) + .await?; + session.token_set = server.refresh(&session.token_set).await?; + self.store + .set(key.clone(), session.clone()) + .await + .map_err(|e| Error::Store(e.to_string()))?; + Ok(session) + } + pub async fn get(&self, key: &Did, refresh: bool) -> Result<Session, Error> { + if refresh { + self.get_refreshed(key).await + } else { + // TODO: cached? + self.store + .get(key) + .await + .map_err(|e| Error::Store(e.to_string()))? + .ok_or(Error::SessionNotFound) + } + } + pub async fn set(&self, key: Did, value: Session) -> Result<(), S::Error> { + self.store.set(key.clone(), value.clone()).await + } + pub async fn del(&self, key: &Did) -> Result<(), S::Error> { + self.store.del(key).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + tests::{ + client_metadata, dpop_key, oauth_resolver, protected_resource_metadata, + server_metadata, MockDidResolver, NoopHandleResolver, + }, + types::{OAuthTokenResponse, OAuthTokenType, RefreshRequestParameters, TokenSet}, + }; + use atrium_common::store::Store; + use atrium_xrpc::http::{header::CONTENT_TYPE, Request, Response, StatusCode}; + use std::{collections::HashMap, time::Duration}; + use tokio::{sync::Mutex, time::sleep}; + + #[derive(Error, Debug)] + enum MockStoreError {} + + struct MockHttpClient { + next_token: Arc<Mutex<Option<OAuthTokenResponse>>>, + } + + impl Default for MockHttpClient { + fn default() -> Self { + Self { + next_token: Arc::new(Mutex::new(Some(OAuthTokenResponse { + access_token: String::from("new_accesstoken"), + token_type: OAuthTokenType::DPoP, + expires_in: Some(10), + refresh_token: Some(String::from("new_refreshtoken")), + scope: None, + sub: None, + }))), + } + } + } + + impl HttpClient for MockHttpClient { + async fn send_http( + &self, + request: Request<Vec<u8>>, + ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> { + println!("{:?}", request); + + Ok(match (request.uri().host(), request.uri().path()) { + (Some("iss.example.com"), "/.well-known/oauth-authorization-server") => { + Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(serde_json::to_vec(&server_metadata())?)? + } + (Some("aud.example.com"), "/.well-known/oauth-protected-resource") => { + Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(serde_json::to_vec(&protected_resource_metadata())?)? + } + (Some("iss.example.com"), "/token") => { + let parameters = + serde_html_form::from_bytes::<RefreshRequestParameters>(request.body())?; + if let Some(token_response) = if parameters.refresh_token == "refreshtoken" { + self.next_token.lock().await.take() + } else { + None + } { + Response::builder() + .header(CONTENT_TYPE, "application/json") + .body(serde_json::to_vec(&token_response)?)? + } else { + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("WWW-Authenticate", "DPoP error=\"invalid_token\"") + .body(Vec::new())? + } + } + _ => { + Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Vec::new())? + } + }) + } + } + + struct MockSessionStore { + hm: Mutex<HashMap<Did, Session>>, + } + + impl Store<Did, Session> for MockSessionStore { + type Error = MockStoreError; + + async fn get(&self, key: &Did) -> Result<Option<Session>, Self::Error> { + sleep(Duration::from_micros(10)).await; + Ok(self.hm.lock().await.get(key).cloned()) + } + async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> { + sleep(Duration::from_micros(10)).await; + self.hm.lock().await.insert(key, value); + Ok(()) + } + async fn del(&self, key: &Did) -> Result<(), Self::Error> { + sleep(Duration::from_micros(10)).await; + self.hm.lock().await.remove(key); + Ok(()) + } + async fn clear(&self) -> Result<(), Self::Error> { + unimplemented!() + } + } + + impl SessionStore for MockSessionStore {} + + impl Default for MockSessionStore { + fn default() -> Self { + Self { hm: Mutex::new(HashMap::from_iter([(did(), session())])) } + } + } + + fn did() -> Did { + "did:fake:handle.test".parse().expect("invalid did") + } + + fn session() -> Session { + let dpop_key = dpop_key(); + let token_set = TokenSet { + iss: String::from("https://iss.example.com"), + sub: did(), + aud: String::from("https://aud.example.com"), + scope: None, + refresh_token: Some(String::from("refreshtoken")), + access_token: String::from("accesstoken"), + token_type: OAuthTokenType::DPoP, + expires_at: None, + }; + Session { dpop_key, token_set } + } + + fn session_registry( + store: MockSessionStore, + ) -> SessionRegistry<MockSessionStore, MockHttpClient, MockDidResolver, NoopHandleResolver> + { + let http_client = Arc::new(MockHttpClient::default()); + SessionRegistry::new( + store, + Arc::new(OAuthServerFactory::new( + client_metadata(), + Arc::new(oauth_resolver(Arc::clone(&http_client))), + http_client, + None, + )), + ) + } + + #[tokio::test] + async fn test_get_session() -> Result<(), Box<dyn std::error::Error>> { + let registry = session_registry(MockSessionStore::default()); + let result = registry.get(&"did:fake:nonexistent".parse()?, false).await; + assert!(matches!(result, Err(Error::SessionNotFound))); + let result = registry.get(&"did:fake:handle.test".parse()?, false).await; + let session = result.expect("handle should exist"); + assert_eq!(session.token_set.access_token, "accesstoken"); + Ok(()) + } + + #[tokio::test] + async fn test_get_refreshed() -> Result<(), Box<dyn std::error::Error>> { + let registry = session_registry(MockSessionStore::default()); + let session = registry.get(&did(), true).await?; + assert_eq!(session.token_set.access_token, "new_accesstoken"); + assert_eq!(session.token_set.refresh_token.as_deref(), Some("new_refreshtoken")); + // second time should return the same session + let session = registry.get(&did(), true).await?; + assert_eq!(session.token_set.access_token, "new_accesstoken"); + assert_eq!(session.token_set.refresh_token.as_deref(), Some("new_refreshtoken")); + Ok(()) + } + + #[tokio::test] + async fn test_get_refreshed_parallel() -> Result<(), Box<dyn std::error::Error>> { + let registry = Arc::new(session_registry(MockSessionStore::default())); + let mut handles = Vec::new(); + for _ in 0..3 { + let registry = Arc::clone(®istry); + handles.push(tokio::spawn(async move { registry.get(&did(), true).await })); + } + for result in futures::future::join_all(handles).await { + match result? { + Ok(session) => { + assert_eq!(session.token_set.access_token, "new_accesstoken"); + assert_eq!( + session.token_set.refresh_token.as_deref(), + Some("new_refreshtoken") + ); + } + Err(err) => { + panic!("unexpected error: {err:?}"); + } + } + } + Ok(()) + } +} diff --git a/atrium-oauth/oauth-client/src/store/state.rs b/atrium-oauth/oauth-client/src/store/state.rs index d55e3234..a39a2cb4 100644 --- a/atrium-oauth/oauth-client/src/store/state.rs +++ b/atrium-oauth/oauth-client/src/store/state.rs @@ -1,5 +1,4 @@ -use super::memory::MemorySimpleStore; -use super::SimpleStore; +use atrium_common::store::{memory::MemoryStore, Store}; use jose_jwk::Key; use serde::{Deserialize, Serialize}; @@ -8,10 +7,11 @@ pub struct InternalStateData { pub iss: String, pub dpop_key: Key, pub verifier: String, + pub app_state: Option<String>, } -pub trait StateStore: SimpleStore<String, InternalStateData> {} +pub trait StateStore: Store<String, InternalStateData> {} -pub type MemoryStateStore = MemorySimpleStore<String, InternalStateData>; +pub type MemoryStateStore = MemoryStore<String, InternalStateData>; impl StateStore for MemoryStateStore {} diff --git a/atrium-oauth/oauth-client/src/types.rs b/atrium-oauth/oauth-client/src/types.rs index a5712674..6bd5f494 100644 --- a/atrium-oauth/oauth-client/src/types.rs +++ b/atrium-oauth/oauth-client/src/types.rs @@ -4,17 +4,13 @@ mod request; mod response; mod token; +pub use self::client_metadata::*; +pub use self::metadata::*; +pub use self::request::*; +pub use self::response::*; +pub use self::token::*; use crate::atproto::{KnownScope, Scope}; -pub use client_metadata::{OAuthClientMetadata, TryIntoOAuthClientMetadata}; -pub use metadata::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; -pub use request::{ - AuthorizationCodeChallengeMethod, AuthorizationResponseType, - PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, - TokenRequestParameters, -}; -pub use response::{OAuthPusehedAuthorizationRequestResponse, OAuthTokenResponse}; use serde::Deserialize; -pub use token::TokenSet; #[derive(Debug, Deserialize)] pub enum AuthorizeOptionPrompt { diff --git a/atrium-oauth/oauth-client/src/types/client_metadata.rs b/atrium-oauth/oauth-client/src/types/client_metadata.rs index 04f2f2bf..b30a23f2 100644 --- a/atrium-oauth/oauth-client/src/types/client_metadata.rs +++ b/atrium-oauth/oauth-client/src/types/client_metadata.rs @@ -2,7 +2,7 @@ use crate::keyset::Keyset; use jose_jwk::JwkSet; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct OAuthClientMetadata { pub client_id: String, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/atrium-oauth/oauth-client/src/types/metadata.rs b/atrium-oauth/oauth-client/src/types/metadata.rs index 0e40c649..2235ea25 100644 --- a/atrium-oauth/oauth-client/src/types/metadata.rs +++ b/atrium-oauth/oauth-client/src/types/metadata.rs @@ -1,7 +1,7 @@ use atrium_api::types::string::Language; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct OAuthAuthorizationServerMetadata { // https://datatracker.ietf.org/doc/html/rfc8414#section-2 pub issuer: String, @@ -50,7 +50,7 @@ pub struct OAuthAuthorizationServerMetadata { // https://datatracker.ietf.org/doc/draft-ietf-oauth-resource-metadata/ // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-08#section-2 -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct OAuthProtectedResourceMetadata { pub resource: String, pub authorization_servers: Option<Vec<String>>, diff --git a/atrium-oauth/oauth-client/src/types/request.rs b/atrium-oauth/oauth-client/src/types/request.rs index d8d352e6..b818f795 100644 --- a/atrium-oauth/oauth-client/src/types/request.rs +++ b/atrium-oauth/oauth-client/src/types/request.rs @@ -1,7 +1,6 @@ -use serde::Serialize; +use serde::{Deserialize, Serialize}; -#[allow(dead_code)] -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AuthorizationResponseType { Code, @@ -10,8 +9,7 @@ pub enum AuthorizationResponseType { IdToken, } -#[allow(dead_code)] -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AuthorizationResponseMode { Query, @@ -20,15 +18,14 @@ pub enum AuthorizationResponseMode { FormPost, } -#[allow(dead_code)] -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] pub enum AuthorizationCodeChallengeMethod { S256, #[serde(rename = "plain")] Plain, } -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] pub struct PushedAuthorizationRequestParameters { // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 pub response_type: AuthorizationResponseType, @@ -45,15 +42,14 @@ pub struct PushedAuthorizationRequestParameters { pub prompt: Option<String>, } -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum TokenGrantType { AuthorizationCode, - #[allow(dead_code)] RefreshToken, } -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] pub struct TokenRequestParameters { // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 pub grant_type: TokenGrantType, @@ -63,10 +59,18 @@ pub struct TokenRequestParameters { pub code_verifier: String, } -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] pub struct RefreshRequestParameters { // https://datatracker.ietf.org/doc/html/rfc6749#section-6 pub grant_type: TokenGrantType, pub refresh_token: String, pub scope: Option<String>, } + +// https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 +#[derive(Serialize, Deserialize)] +pub struct RevocationRequestParameters { + pub token: String, + // ? + // pub token_type_hint: Option<String>, +} diff --git a/atrium-oauth/oauth-client/src/types/token.rs b/atrium-oauth/oauth-client/src/types/token.rs index 069e9fef..d09736e0 100644 --- a/atrium-oauth/oauth-client/src/types/token.rs +++ b/atrium-oauth/oauth-client/src/types/token.rs @@ -1,11 +1,11 @@ use super::response::OAuthTokenType; -use atrium_api::types::string::Datetime; +use atrium_api::types::string::{Datetime, Did}; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct TokenSet { pub iss: String, - pub sub: String, + pub sub: Did, pub aud: String, pub scope: Option<String>, diff --git a/atrium-xrpc/src/traits.rs b/atrium-xrpc/src/traits.rs index fa8d48c3..15705220 100644 --- a/atrium-xrpc/src/traits.rs +++ b/atrium-xrpc/src/traits.rs @@ -131,7 +131,7 @@ where .headers .get(http::header::CONTENT_TYPE) .and_then(|value| value.to_str().ok()) - .map_or(false, |content_type| content_type.starts_with("application/json")) + .is_some_and(|content_type| content_type.starts_with("application/json")) { Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?)) } else {