Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feat: OAuth session #243

Draft
wants to merge 46 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
5b3d3e8
Move AtpAgent
sugyan Nov 7, 2024
71f8cff
Add Agent and SessionManager
sugyan Nov 8, 2024
6734492
Temporary fix for bsky-sdk
sugyan Nov 8, 2024
d041ae7
Add OAuthSession
sugyan Nov 13, 2024
1e7805d
Update
sugyan Nov 14, 2024
f5be54a
Update oauth_client::atproto
sugyan Nov 18, 2024
87231c0
Merge branch 'main' into feature/agent-rework
sugyan Nov 18, 2024
0676695
Add refresh token request
sugyan Nov 21, 2024
1933080
Merge branch 'main' into feature/agent-rework
sugyan Nov 21, 2024
a109ef5
Update stores
sugyan Nov 21, 2024
a395d90
Add SessionStore and SessionGetter
sugyan Nov 21, 2024
5bad0a1
Use common for api
sugyan Nov 26, 2024
04fd7dc
Fix bsky-sdk
sugyan Nov 26, 2024
fc3ef22
Fix oauth-client
sugyan Nov 28, 2024
7336b8c
Extract WrapperClient and InnerStore to agent.rs
sugyan Dec 31, 2024
d987bb8
Add Configure and CloneWithProxy trait for agent
sugyan Jan 1, 2025
8810bc7
Update AtpAgent
sugyan Jan 4, 2025
f00b866
Merge branch 'main' into feature/agent-rework
sugyan Jan 4, 2025
26f1d04
Update
sugyan Jan 15, 2025
870f85a
Add tests for oauth_session, implement oauth_session::store
sugyan Jan 16, 2025
b47bfa1
WIP: Add tests for OAuthSession, update ServerAgent
sugyan Jan 22, 2025
ba3ada2
Implement refresh token
sugyan Jan 30, 2025
498b6be
Update atrium-api/agent
sugyan Feb 13, 2025
5aa65b5
Merge branch 'main' into feature/agent-rework
sugyan Feb 13, 2025
bf05936
Fix error
sugyan Feb 13, 2025
e48bace
Fix bsky_sdk
sugyan Feb 13, 2025
ae297d3
Fix oauth-client
sugyan Feb 14, 2025
2aa7282
Merge branch 'main' into feature/agent-rework
sugyan Feb 18, 2025
54b1daa
Merge branch 'main' into feature/agent-rework
sugyan Feb 23, 2025
6b5633e
Implement SessionGetter and SessionHandle
sugyan Feb 25, 2025
0939b0e
Remove unused code
sugyan Feb 26, 2025
5792b4d
Fix workflows
sugyan Feb 26, 2025
63579f6
Use is_some_and/is_ok_and
sugyan Feb 26, 2025
91713aa
Add tests for OAuthSession
sugyan Feb 27, 2025
0adb558
Implement OAuthClient::restore
sugyan Mar 2, 2025
198b171
Implement OAuthClient::revoke()
sugyan Mar 3, 2025
590968e
Rename session_getter to session_registry
sugyan Mar 3, 2025
ea58a43
Fix for edition 2024
sugyan Mar 4, 2025
cada5f4
Merge branch 'main' into feature/agent-rework
sugyan Mar 8, 2025
5f15158
WIP: Update SessionRegistry
sugyan Mar 15, 2025
7ec0d0d
Update SessionRegistry
sugyan Mar 16, 2025
295c3d3
WIP
sugyan Mar 22, 2025
54e8526
Add session_registry::tests
sugyan Mar 27, 2025
26e9839
Fix tests
sugyan Mar 27, 2025
13f62be
Remove DidResolver
sugyan Mar 28, 2025
ab1c071
Remove HandleResolver
sugyan Mar 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/wasm.yml
Original file line number Diff line number Diff line change
@@ -20,6 +20,10 @@ jobs:
- uses: actions/checkout@v4
with:
path: crates
- name: Install Rust 1.75.0
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: 1.75.0
# We use a synthetic crate to ensure no dev-dependencies are enabled, which can
# be incompatible with some of these targets.
- name: Create synthetic crate for testing
10 changes: 6 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions atrium-api/src/agent/atp_agent.rs
Original file line number Diff line number Diff line change
@@ -324,6 +324,7 @@ mod tests {
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
// tick tokio time
#[cfg(not(target_arch = "wasm32"))]
tokio::time::sleep(std::time::Duration::from_micros(10)).await;

16 changes: 0 additions & 16 deletions atrium-api/src/agent/store.rs

This file was deleted.

20 changes: 0 additions & 20 deletions atrium-api/src/agent/store/memory.rs

This file was deleted.

12 changes: 6 additions & 6 deletions atrium-common/src/store/memory.rs
Original file line number Diff line number Diff line change
@@ -2,14 +2,14 @@ use super::Store;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::Mutex;

#[derive(Error, Debug)]
#[error("memory store error")]
pub struct Error;

// TODO: LRU cache?
#[derive(Clone)]
pub struct MemoryStore<K, V> {
store: Arc<Mutex<HashMap<K, V>>>,
@@ -29,18 +29,18 @@ where
type Error = Error;

async fn get(&self, key: &K) -> Result<Option<V>, Self::Error> {
Ok(self.store.lock().unwrap().get(key).cloned())
Ok(self.store.lock().await.get(key).cloned())
}
async fn set(&self, key: K, value: V) -> Result<(), Self::Error> {
self.store.lock().unwrap().insert(key, value);
self.store.lock().await.insert(key, value);
Ok(())
}
async fn del(&self, key: &K) -> Result<(), Self::Error> {
self.store.lock().unwrap().remove(key);
self.store.lock().await.remove(key);
Ok(())
}
async fn clear(&self) -> Result<(), Self::Error> {
self.store.lock().unwrap().clear();
self.store.lock().await.clear();
Ok(())
}
}
3 changes: 1 addition & 2 deletions atrium-identity/src/did.rs
Original file line number Diff line number Diff line change
@@ -2,10 +2,9 @@ mod common_resolver;
mod plc_resolver;
mod web_resolver;

use crate::Error;

pub use self::common_resolver::{CommonDidResolver, CommonDidResolverConfig};
pub use self::plc_resolver::DEFAULT_PLC_DIRECTORY_URL;
use crate::Error;
use atrium_api::did_doc::DidDocument;
use atrium_api::types::string::Did;
use atrium_common::resolver::Resolver;
10 changes: 6 additions & 4 deletions atrium-identity/src/identity_resolver.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::error::{Error, Result};
use crate::{did::DidResolver, handle::HandleResolver};
use atrium_api::types::string::AtIdentifier;
use atrium_api::{
did_doc::DidDocument,
types::string::{AtIdentifier, Did, Handle},
};
use atrium_common::resolver::Resolver;
use serde::{Deserialize, Serialize};

@@ -29,8 +31,8 @@ impl<D, H> IdentityResolver<D, H> {

impl<D, H> Resolver for IdentityResolver<D, H>
where
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
{
type Input = str;
type Output = ResolvedIdentity;
8 changes: 6 additions & 2 deletions atrium-oauth/oauth-client/Cargo.toml
Original file line number Diff line number Diff line change
@@ -14,12 +14,13 @@ keywords = ["atproto", "bluesky", "oauth"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
atrium-api = { workspace = true, default-features = false }
atrium-api = { workspace = true, features = ["agent"] }
atrium-common.workspace = true
atrium-identity.workspace = true
atrium-xrpc.workspace = true
base64.workspace = true
chrono.workspace = true
dashmap.workspace = true
ecdsa = { workspace = true, features = ["signing"] }
elliptic-curve.workspace = true
jose-jwa.workspace = true
@@ -32,11 +33,14 @@ serde_html_form.workspace = true
serde_json.workspace = true
sha2.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["sync"] }
trait-variant.workspace = true

[dev-dependencies]
atrium-api = { workspace = true, features = ["bluesky"] }
futures.workspace = true
hickory-resolver.workspace = true
p256 = { workspace = true, features = ["pem"] }
p256 = { workspace = true, features = ["pem", "std"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }

[features]
28 changes: 28 additions & 0 deletions atrium-oauth/oauth-client/examples/generate_key.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use elliptic_curve::pkcs8::EncodePrivateKey;
use elliptic_curve::SecretKey;
use jose_jwa::{Algorithm, Signing};
use jose_jwk::{Class, Jwk, JwkSet, Key, Parameters};
use p256::NistP256;
use rand::rngs::ThreadRng;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let secret_key = SecretKey::<NistP256>::random(&mut ThreadRng::default());
let key = Key::from(&secret_key.public_key().into());
let jwks = JwkSet {
keys: vec![Jwk {
key,
prm: Parameters {
alg: Some(Algorithm::Signing(Signing::Es256)),
kid: Some(String::from("kid01")),
cls: Some(Class::Signing),
..Default::default()
},
}],
};
println!("SECRET KEY:");
println!("{}", secret_key.to_pkcs8_pem(Default::default())?.as_str());

println!("JWKS:");
println!("{}", serde_json::to_string_pretty(&jwks)?);
Ok(())
}
26 changes: 23 additions & 3 deletions atrium-oauth/oauth-client/examples/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use atrium_api::agent::Agent;
use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL};
use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver};
use atrium_oauth_client::store::session::MemorySessionStore;
use atrium_oauth_client::store::state::MemoryStateStore;
use atrium_oauth_client::{
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient,
@@ -57,6 +59,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
protected_resource_metadata: Default::default(),
},
state_store: MemoryStateStore::default(),
session_store: MemorySessionStore::default(),
};
let client = OAuthClient::new(config)?;
println!(
@@ -76,7 +79,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
);

// Click the URL and sign in,
// then copy and paste the URL like “http://127.0.0.1/?iss=...&code=...” after it is redirected.
// then copy and paste the URL like “http://127.0.0.1/callback?iss=...&code=...” after it is redirected.

print!("Redirected url: ");
stdout().lock().flush()?;
@@ -85,7 +88,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let uri = url.trim().parse::<Uri>()?;
let params = serde_html_form::from_str(uri.query().unwrap())?;
println!("{}", serde_json::to_string_pretty(&client.callback(params).await?)?);

let (session, _) = client.callback(params).await?;
let agent = Agent::new(session);
let output = agent
.api
.app
.bsky
.feed
.get_timeline(
atrium_api::app::bsky::feed::get_timeline::ParametersData {
algorithm: None,
cursor: None,
limit: 3.try_into().ok(),
}
.into(),
)
.await?;
for feed in &output.feed {
println!("{feed:?}");
}
Ok(())
}
14 changes: 11 additions & 3 deletions atrium-oauth/oauth-client/src/error.rs
Original file line number Diff line number Diff line change
@@ -5,17 +5,25 @@ pub enum Error {
#[error(transparent)]
ClientMetadata(#[from] crate::atproto::Error),
#[error(transparent)]
Keyset(#[from] crate::keyset::Error),
Dpop(#[from] crate::http_client::dpop::Error),
#[error(transparent)]
Identity(#[from] atrium_identity::Error),
Keyset(#[from] crate::keyset::Error),
#[error(transparent)]
ServerAgent(#[from] crate::server_agent::Error),
#[error(transparent)]
OAuthSession(#[from] crate::oauth_session::Error),
#[error(transparent)]
SessionRegistry(#[from] crate::store::session_registry::Error),
#[error(transparent)]
Identity(#[from] atrium_identity::Error),
#[error("authorize error: {0}")]
Authorize(String),
#[error("callback error: {0}")]
Callback(String),
#[error("state store error: {0:?}")]
#[error("state store error: {0}")]
StateStore(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("session store error: {0}")]
SessionStore(Box<dyn std::error::Error + Send + Sync + 'static>),
}

pub type Result<T> = core::result::Result<T, Error>;
50 changes: 33 additions & 17 deletions atrium-oauth/oauth-client/src/http_client/dpop.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::jose::create_signed_jwt;
use crate::jose::jws::RegisteredHeader;
use crate::jose::jwt::{Claims, PublicClaims, RegisteredClaims};
use crate::store::memory::MemorySimpleStore;
use crate::store::SimpleStore;
use atrium_xrpc::http::{Request, Response};
use atrium_xrpc::HttpClient;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use crate::jose::{
create_signed_jwt,
jws::RegisteredHeader,
jwt::{Claims, PublicClaims, RegisteredClaims},
};
use atrium_common::store::{memory::MemoryStore, Store};
use atrium_xrpc::{
http::{Request, Response},
HttpClient,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use chrono::Utc;
use jose_jwa::{Algorithm, Signing};
use jose_jwk::{crypto, EcCurves, Jwk, Key};
use rand::rngs::SmallRng;
use rand::{RngCore, SeedableRng};
use rand::{
rngs::SmallRng,
{RngCore, SeedableRng},
};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::sync::Arc;
@@ -36,9 +40,9 @@ pub enum Error {

type Result<T> = core::result::Result<T, Error>;

pub struct DpopClient<T, S = MemorySimpleStore<String, String>>
pub struct DpopClient<T, S = MemoryStore<String, String>>
where
S: SimpleStore<String, String>,
S: Store<String, String>,
{
inner: Arc<T>,
pub(crate) key: Key,
@@ -65,14 +69,14 @@ impl<T> DpopClient<T> {
return Err(Error::UnsupportedKey);
}
}
let nonces = MemorySimpleStore::<String, String>::default();
let nonces = MemoryStore::<String, String>::default();
Ok(Self { inner: http_client, key, nonces, is_auth_server })
}
}

impl<T, S> DpopClient<T, S>
where
S: SimpleStore<String, String>,
S: Store<String, String>,
{
fn build_proof(
&self,
@@ -135,7 +139,8 @@ where
impl<T, S> HttpClient for DpopClient<T, S>
where
T: HttpClient + Send + Sync + 'static,
S: SimpleStore<String, String> + Send + Sync + 'static,
S: Store<String, String> + Send + Sync + 'static,
S::Error: std::error::Error + Send + Sync + 'static,
{
async fn send_http(
&self,
@@ -150,7 +155,7 @@ where
let ath = request
.headers()
.get("Authorization")
.filter(|v| v.to_str().map_or(false, |s| s.starts_with("DPoP ")))
.filter(|v| v.to_str().is_ok_and(|s| s.starts_with("DPoP ")))
.map(|auth| URL_SAFE_NO_PAD.encode(Sha256::digest(&auth.as_bytes()[5..])));

let init_nonce = self.nonces.get(&nonce_key).await?;
@@ -182,3 +187,14 @@ where
Ok(response)
}
}

impl<T> Clone for DpopClient<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
key: self.key.clone(),
nonces: self.nonces.clone(),
is_auth_server: self.is_auth_server,
}
}
}
2 changes: 1 addition & 1 deletion atrium-oauth/oauth-client/src/keyset.rs
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ impl Keyset {
.0
.iter()
.filter_map(|key| {
if key.prm.cls.map_or(false, |c| c != cls) {
if key.prm.cls.is_some_and(|c| c != cls) {
return None;
}
let alg = match &key.key {
110 changes: 110 additions & 0 deletions atrium-oauth/oauth-client/src/lib.rs
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ mod http_client;
mod jose;
mod keyset;
mod oauth_client;
mod oauth_session;
mod resolver;
mod server_agent;
pub mod store;
@@ -19,7 +20,116 @@ pub use error::{Error, Result};
pub use http_client::default::DefaultHttpClient;
pub use http_client::dpop::DpopClient;
pub use oauth_client::{OAuthClient, OAuthClientConfig};
pub use oauth_session::OAuthSession;
pub use resolver::OAuthResolverConfig;
pub use types::{
AuthorizeOptionPrompt, AuthorizeOptions, CallbackParams, OAuthClientMetadata, TokenSet,
};

#[cfg(test)]
mod tests {
use crate::{
resolver::OAuthResolver,
types::{
OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthProtectedResourceMetadata,
TryIntoOAuthClientMetadata,
},
AtprotoLocalhostClientMetadata, OAuthResolverConfig,
};
use atrium_api::{
did_doc::{DidDocument, Service},
types::string::{Did, Handle},
};
use atrium_common::resolver::Resolver;
use atrium_xrpc::HttpClient;
use jose_jwk::Key;
use std::sync::Arc;

pub struct MockDidResolver;

impl Resolver for MockDidResolver {
type Input = Did;
type Output = DidDocument;
type Error = atrium_identity::Error;
async fn resolve(&self, did: &Self::Input) -> Result<Self::Output, Self::Error> {
Ok(DidDocument {
context: None,
id: did.as_ref().to_string(),
also_known_as: None,
verification_method: None,
service: Some(vec![Service {
id: String::from("#atproto_pds"),
r#type: String::from("AtprotoPersonalDataServer"),
service_endpoint: String::from("https://aud.example.com"),
}]),
})
}
}

pub struct NoopHandleResolver;

impl Resolver for NoopHandleResolver {
type Input = Handle;
type Output = Did;
type Error = atrium_identity::Error;
async fn resolve(&self, _: &Self::Input) -> Result<Self::Output, Self::Error> {
unimplemented!()
}
}

pub fn oauth_resolver<T>(
http_client: Arc<T>,
) -> OAuthResolver<T, MockDidResolver, NoopHandleResolver>
where
T: HttpClient + Send + Sync,
{
OAuthResolver::new(
OAuthResolverConfig {
did_resolver: MockDidResolver,
handle_resolver: NoopHandleResolver,
authorization_server_metadata: Default::default(),
protected_resource_metadata: Default::default(),
},
http_client,
)
}

pub fn dpop_key() -> Key {
serde_json::from_str(
r#"{
"kty": "EC",
"crv": "P-256",
"x": "NIRNgPVAwnVNzN5g2Ik2IMghWcjnBOGo9B-lKXSSXFs",
"y": "iWF-Of43XoSTZxcadO9KWdPTjiCoviSztYw7aMtZZMc",
"d": "9MuCYfKK4hf95p_VRj6cxKJwORTgvEU3vynfmSgFH2M"
}"#,
)
.expect("key should be valid")
}

pub fn server_metadata() -> OAuthAuthorizationServerMetadata {
OAuthAuthorizationServerMetadata {
issuer: String::from("https://iss.example.com"),
token_endpoint: String::from("https://iss.example.com/token"),
token_endpoint_auth_methods_supported: Some(vec![
String::from("none"),
String::from("private_key_jwt"),
]),
..Default::default()
}
}

pub fn client_metadata() -> OAuthClientMetadata {
AtprotoLocalhostClientMetadata::default()
.try_into_client_metadata(&None)
.expect("client metadata should be valid")
}

pub fn protected_resource_metadata() -> OAuthProtectedResourceMetadata {
OAuthProtectedResourceMetadata {
resource: String::from("https://aud.example.com"),
authorization_servers: Some(vec![String::from("https://iss.example.com")]),
..Default::default()
}
}
}
205 changes: 144 additions & 61 deletions atrium-oauth/oauth-client/src/oauth_client.rs
Original file line number Diff line number Diff line change
@@ -1,138 +1,184 @@
use crate::constants::FALLBACK_ALG;
use crate::error::{Error, Result};
use crate::keyset::Keyset;
use crate::resolver::{OAuthResolver, OAuthResolverConfig};
use crate::server_agent::{OAuthRequest, OAuthServerAgent};
use crate::store::state::{InternalStateData, StateStore};
use crate::types::{
AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions, CallbackParams,
OAuthAuthorizationServerMetadata, OAuthClientMetadata,
OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters, TokenSet,
TryIntoOAuthClientMetadata,
use crate::{
constants::FALLBACK_ALG,
error::{Error, Result},
keyset::Keyset,
oauth_session::OAuthSession,
resolver::{OAuthResolver, OAuthResolverConfig},
server_agent::{OAuthRequest, OAuthServerAgent, OAuthServerFactory},
store::{
session::{Session, SessionStore},
session_registry::SessionRegistry,
state::{InternalStateData, StateStore},
},
types::{
AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions,
CallbackParams, OAuthAuthorizationServerMetadata, OAuthClientMetadata,
OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters,
TryIntoOAuthClientMetadata,
},
utils::{compare_algos, generate_key, generate_nonce},
};
use atrium_api::{
did_doc::DidDocument,
types::string::{Did, Handle},
};
use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values};
use atrium_common::resolver::Resolver;
use atrium_identity::{did::DidResolver, handle::HandleResolver};
use atrium_xrpc::HttpClient;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use jose_jwk::{Jwk, JwkSet, Key};
use rand::rngs::ThreadRng;
use serde::Serialize;
use sha2::{Digest, Sha256};
use std::sync::Arc;

#[cfg(feature = "default-client")]
pub struct OAuthClientConfig<S, M, D, H>
pub struct OAuthClientConfig<S0, S1, M, D, H>
where
M: TryIntoOAuthClientMetadata,
{
// Config
pub client_metadata: M,
pub keys: Option<Vec<Jwk>>,
// Stores
pub state_store: S,
pub state_store: S0,
pub session_store: S1,
// Services
pub resolver: OAuthResolverConfig<D, H>,
}

#[cfg(not(feature = "default-client"))]
pub struct OAuthClientConfig<S, T, M, D, H>
pub struct OAuthClientConfig<S0, S1, T, M, D, H>
where
M: TryIntoOAuthClientMetadata,
{
// Config
pub client_metadata: M,
pub keys: Option<Vec<Jwk>>,
// Stores
pub state_store: S,
pub state_store: S0,
pub session_store: S1,
// Services
pub resolver: OAuthResolverConfig<D, H>,
// Others
pub http_client: T,
}

#[cfg(feature = "default-client")]
pub struct OAuthClient<S, D, H, T = crate::http_client::default::DefaultHttpClient>
pub struct OAuthClient<S0, S1, D, H, T = crate::http_client::default::DefaultHttpClient>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
S1: SessionStore + Send + Sync + 'static,
S1::Error: std::error::Error + Send + Sync + 'static,
{
pub client_metadata: OAuthClientMetadata,
keyset: Option<Keyset>,
resolver: Arc<OAuthResolver<T, D, H>>,
state_store: S,
server_factory: Arc<OAuthServerFactory<T, D, H>>,
state_store: S0,
session_registry: Arc<SessionRegistry<S1, T, D, H>>,
http_client: Arc<T>,
}

#[cfg(not(feature = "default-client"))]
pub struct OAuthClient<S, D, H, T>
pub struct OAuthClient<S0, S1, D, H, T>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
S1: SessionStore + Send + Sync + 'static,
S1::Error: std::error::Error + Send + Sync + 'static,
{
pub client_metadata: OAuthClientMetadata,
keyset: Option<Keyset>,
resolver: Arc<OAuthResolver<T, D, H>>,
state_store: S,
server_factory: Arc<OAuthServerFactory<T, D, H>>,
state_store: S0,
session_registry: Arc<SessionRegistry<S1, T, D, H>>,
http_client: Arc<T>,
}

#[cfg(feature = "default-client")]
impl<S, D, H> OAuthClient<S, D, H, crate::http_client::default::DefaultHttpClient>
impl<S0, S1, D, H> OAuthClient<S0, S1, D, H, crate::http_client::default::DefaultHttpClient>
where
S: StateStore,
S1: SessionStore + Send + Sync + 'static,
S1::Error: std::error::Error + Send + Sync + 'static,
{
pub fn new<M>(config: OAuthClientConfig<S, M, D, H>) -> Result<Self>
pub fn new<M>(config: OAuthClientConfig<S0, S1, M, D, H>) -> Result<Self>
where
M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>,
{
let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None };
let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?;
let http_client = Arc::new(crate::http_client::default::DefaultHttpClient::default());
let resolver = Arc::new(OAuthResolver::new(config.resolver, Arc::clone(&http_client)));
let server_factory = Arc::new(OAuthServerFactory::new(
client_metadata.clone(),
Arc::clone(&resolver),
Arc::clone(&http_client),
keyset.clone(),
));
let session_registry =
Arc::new(SessionRegistry::new(config.session_store, Arc::clone(&server_factory)));
Ok(Self {
client_metadata,
keyset,
resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())),
resolver,
server_factory,
state_store: config.state_store,
session_registry,
http_client,
})
}
}

#[cfg(not(feature = "default-client"))]
impl<S, D, H, T> OAuthClient<S, D, H, T>
impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
S1: SessionStore + Send + Sync + 'static,
S1::Error: std::error::Error + Send + Sync + 'static,
{
pub fn new<M>(config: OAuthClientConfig<S, T, M, D, H>) -> Result<Self>
pub fn new<M>(config: OAuthClientConfig<S0, S1, T, M, D, H>) -> Result<Self>
where
M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>,
{
let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None };
let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?;
let http_client = Arc::new(config.http_client);
let resolver = Arc::new(OAuthResolver::new(config.resolver, Arc::clone(&http_client)));
let server_factory = Arc::new(OAuthServerFactory::new(
client_metadata.clone(),
Arc::clone(&resolver),
Arc::clone(&http_client),
keyset.clone(),
));
let session_registry =
Arc::new(SessionRegistry::new(config.session_store, Arc::clone(&server_factory)));
Ok(Self {
client_metadata,
keyset,
resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())),
resolver,
server_factory,
state_store: config.state_store,
session_registry,
http_client,
})
}
}

impl<S, D, H, T> OAuthClient<S, D, H, T>
impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T>
where
S: StateStore,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
S0: StateStore + Send + Sync + 'static,
S1: SessionStore + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
T: HttpClient + Send + Sync + 'static,
S0::Error: std::error::Error + Send + Sync + 'static,
S1::Error: std::error::Error + Send + Sync + 'static,
{
pub fn jwks(&self) -> JwkSet {
self.keyset.as_ref().map(|keyset| keyset.public_jwks()).unwrap_or_default()
}
/// Start the authorization process.
///
/// This method will return a URL that the user should be redirected to.
pub async fn authorize(
&self,
input: impl AsRef<str>,
@@ -156,6 +202,7 @@ where
iss: metadata.issuer.clone(),
dpop_key: dpop_key.clone(),
verifier,
app_state: options.state,
};
self.state_store
.set(state.clone(), state_data)
@@ -174,14 +221,7 @@ where
prompt: options.prompt.map(String::from),
};
if metadata.pushed_authorization_request_endpoint.is_some() {
let server = OAuthServerAgent::new(
dpop_key,
metadata.clone(),
self.client_metadata.clone(),
self.resolver.clone(),
self.http_client.clone(),
self.keyset.clone(),
)?;
let server = self.server_factory.build_from_metadata(dpop_key, metadata.clone())?;
let par_response = server
.request::<OAuthPusehedAuthorizationRequestResponse>(
OAuthRequest::PushedAuthorizationRequest(parameters),
@@ -208,7 +248,14 @@ where
todo!()
}
}
pub async fn callback(&self, params: CallbackParams) -> Result<TokenSet> {
/// Handle the callback from the authorization server.
///
/// This method will exchange the authorization code for an access token and store the session,
/// and return the [`OAuthSession`] and the application state.
pub async fn callback(
&self,
params: CallbackParams,
) -> Result<(OAuthSession<T, D, H, S1>, Option<String>)> {
let Some(state_key) = params.state else {
return Err(Error::Callback("missing `state` parameter".into()));
};
@@ -233,18 +280,55 @@ where
} else if metadata.authorization_response_iss_parameter_supported == Some(true) {
return Err(Error::Callback("missing `iss` parameter".into()));
}
let server = OAuthServerAgent::new(
state.dpop_key.clone(),
metadata.clone(),
self.client_metadata.clone(),
self.resolver.clone(),
self.http_client.clone(),
self.keyset.clone(),
)?;
let token_set = server.exchange_code(&params.code, &state.verifier).await?;

// TODO: create session?
Ok(token_set)
let server =
self.server_factory.build_from_metadata(state.dpop_key.clone(), metadata.clone())?;
match server.exchange_code(&params.code, &state.verifier).await {
Ok(token_set) => {
let sub = token_set.sub.clone();
self.session_registry
.set(sub.clone(), Session { dpop_key: state.dpop_key.clone(), token_set })
.await
.map_err(|e| Error::SessionStore(Box::new(e)))?;
Ok((self.create_session(server, &sub).await?, state.app_state))
}
Err(_) => {
todo!()
}
}
}
/// Load a stored session by giving the subject DID.
///
/// This method will return the [`OAuthSession`] if it exists.
pub async fn restore(&self, sub: &Did) -> Result<OAuthSession<T, D, H, S1>> {
// let session_handle = self.session_registry.get(sub).await?;
// let session = session_handle.read().await.session();
let session = self.session_registry.get(sub, false).await?;
self.create_session(
self.server_factory.build_from_issuer(session.dpop_key, &session.token_set.iss).await?,
sub,
)
.await
}
/// Revoke a session by giving the subject DID.
pub async fn revoke(&self, sub: &Did) -> Result<()> {
let session = self.session_registry.get(sub, false).await?;
let server_agent =
self.server_factory.build_from_issuer(session.dpop_key, &session.token_set.iss).await?;
server_agent.revoke(&session.token_set.access_token).await?;
self.session_registry.del(sub).await.map_err(|e| Error::SessionStore(Box::new(e)))
}
async fn create_session(
&self,
server: OAuthServerAgent<T, D, H>,
sub: &Did,
) -> Result<OAuthSession<T, D, H, S1>> {
Ok(OAuthSession::new(
server.server_metadata.clone(),
sub.clone(),
Arc::clone(&self.http_client),
Arc::clone(&self.session_registry),
)
.await?)
}
fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option<Key> {
let mut algs =
@@ -254,8 +338,7 @@ where
}
fn generate_pkce() -> (String, String) {
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
let verifier =
URL_SAFE_NO_PAD.encode(get_random_values::<_, 32>(&mut ThreadRng::default()));
let verifier = [generate_nonce(), generate_nonce()].join("");
(URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier)
}
}
664 changes: 664 additions & 0 deletions atrium-oauth/oauth-client/src/oauth_session.rs

Large diffs are not rendered by default.

154 changes: 154 additions & 0 deletions atrium-oauth/oauth-client/src/oauth_session/inner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
use super::store::MemorySessionStore;
use crate::{
store::{session::SessionStore, session_registry::SessionRegistry},
DpopClient,
};
use atrium_api::{
agent::{
utils::{SessionClient, SessionWithEndpointStore},
CloneWithProxy, Configure,
},
did_doc::DidDocument,
types::string::{Did, Handle},
};
use atrium_common::resolver::Resolver;
use atrium_xrpc::{
http::{Request, Response},
Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{fmt::Debug, sync::Arc};

pub struct Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
{
inner: SessionClient<MemorySessionStore, DpopClient<T>, String>,
store: Arc<SessionWithEndpointStore<MemorySessionStore, String>>,
sub: Did,
session_registry: Arc<SessionRegistry<S, T, D, H>>,
}

impl<S, T, D, H> Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
{
pub fn new(
store: Arc<SessionWithEndpointStore<MemorySessionStore, String>>,
xrpc: DpopClient<T>,
sub: Did,
session_registry: Arc<SessionRegistry<S, T, D, H>>,
) -> Self {
let inner = SessionClient::new(Arc::clone(&store), xrpc);
Self { inner, store, sub, session_registry }
}
}

impl<S, T, D, H> Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
{
// https://datatracker.ietf.org/doc/html/rfc6750#section-3
fn is_invalid_token_response<O, E>(result: &Result<OutputDataOrBytes<O>, Error<E>>) -> bool
where
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
{
match result {
Err(Error::Authentication(value)) => value
.to_str()
.is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
_ => false,
}
}
async fn refresh_token(&self) {
if let Ok(session) = self.session_registry.get(&self.sub, true).await {
let _ = self.store.set(session.token_set.access_token.clone()).await;
}
}
}

impl<S, T, D, H> HttpClient for Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
D: Send + Sync,
H: Send + Sync,
{
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
self.inner.send_http(request).await
}
}

impl<S, T, D, H> XrpcClient for Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
{
fn base_uri(&self) -> String {
self.inner.base_uri()
}
async fn send_xrpc<P, I, O, E>(
&self,
request: &XrpcRequest<P, I>,
) -> Result<OutputDataOrBytes<O>, Error<E>>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync + Debug,
{
let result = self.inner.send_xrpc(request).await;
// handle session-refreshes as needed
if Self::is_invalid_token_response(&result) {
self.refresh_token().await;
self.inner.send_xrpc(request).await
} else {
result
}
}
}

impl<S, T, D, H> Configure for Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
{
fn configure_endpoint(&self, endpoint: String) {
self.inner.configure_endpoint(endpoint)
}
/// Configures the moderation services to be applied on requests.
fn configure_labelers_header(&self, labeler_dids: Option<Vec<(Did, bool)>>) {
self.inner.configure_labelers_header(labeler_dids)
}
/// Configures the atproto-proxy header to be applied on requests.
fn configure_proxy_header(&self, did: Did, service_type: impl AsRef<str>) {
self.inner.configure_proxy_header(did, service_type)
}
}

impl<S, T, D, H> CloneWithProxy for Client<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
SessionClient<S, T, String>: CloneWithProxy,
{
fn clone_with_proxy(&self, did: Did, service_type: impl AsRef<str>) -> Self {
Self {
inner: self.inner.clone_with_proxy(did, service_type),
store: Arc::clone(&self.store),
sub: self.sub.clone(),
session_registry: Arc::clone(&self.session_registry),
}
}
}
29 changes: 29 additions & 0 deletions atrium-oauth/oauth-client/src/oauth_session/store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use atrium_api::agent::AuthorizationProvider;
use atrium_common::store::{self, memory::MemoryStore, Store};
use atrium_xrpc::types::AuthorizationToken;

#[derive(Default)]
pub struct MemorySessionStore(MemoryStore<(), String>);

impl Store<(), String> for MemorySessionStore {
type Error = store::memory::Error;

async fn get(&self, key: &()) -> Result<Option<String>, Self::Error> {
self.0.get(key).await
}
async fn set(&self, key: (), value: String) -> Result<(), Self::Error> {
self.0.set(key, value).await
}
async fn del(&self, key: &()) -> Result<(), Self::Error> {
self.0.del(key).await
}
async fn clear(&self) -> Result<(), Self::Error> {
self.0.clear().await
}
}

impl AuthorizationProvider for MemorySessionStore {
async fn authorization_token(&self, _: bool) -> Option<AuthorizationToken> {
self.0.get(&()).await.ok().flatten().map(AuthorizationToken::Dpop)
}
}
41 changes: 22 additions & 19 deletions atrium-oauth/oauth-client/src/resolver.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
use atrium_common::resolver::CachedResolver;
use atrium_common::resolver::Resolver;
use atrium_common::resolver::ThrottledResolver;
use atrium_common::types::cached::r#impl::Cache;
use atrium_common::types::cached::r#impl::CacheImpl;
use atrium_common::types::cached::CacheConfig;
use atrium_common::types::cached::Cacheable;
use atrium_common::types::throttled::Throttleable;
mod oauth_authorization_server_resolver;
mod oauth_protected_resource_resolver;

use self::oauth_authorization_server_resolver::DefaultOAuthAuthorizationServerResolver;
use self::oauth_protected_resource_resolver::DefaultOAuthProtectedResourceResolver;
use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata};
use atrium_identity::identity_resolver::{
IdentityResolver, IdentityResolverConfig, ResolvedIdentity,
use atrium_api::{
did_doc::DidDocument,
types::string::{Did, Handle},
};
use atrium_common::{
resolver::{CachedResolver, Resolver, ThrottledResolver},
types::{
cached::{
r#impl::{Cache, CacheImpl},
{CacheConfig, Cacheable},
},
throttled::Throttleable,
},
};
use atrium_identity::{
identity_resolver::{IdentityResolver, IdentityResolverConfig, ResolvedIdentity},
{Error, Result},
};
use atrium_identity::{did::DidResolver, handle::HandleResolver};
use atrium_identity::{Error, Result};
use atrium_xrpc::HttpClient;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use std::{marker::PhantomData, sync::Arc, time::Duration};

#[derive(Clone, Debug)]
pub struct OAuthAuthorizationServerMetadataResolverConfig {
@@ -106,8 +109,8 @@ where
impl<T, D, H> OAuthResolver<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
{
pub async fn get_authorization_server_metadata(
&self,
@@ -185,8 +188,8 @@ where
impl<T, D, H> Resolver for OAuthResolver<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
{
type Input = str;
type Output = (OAuthAuthorizationServerMetadata, Option<ResolvedIdentity>);
198 changes: 145 additions & 53 deletions atrium-oauth/oauth-client/src/server_agent.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
use crate::constants::FALLBACK_ALG;
use crate::http_client::dpop::DpopClient;
use crate::jose::jwt::{RegisteredClaims, RegisteredClaimsAud};
use crate::keyset::Keyset;
use crate::resolver::OAuthResolver;
use crate::types::{
OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse,
PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType,
TokenRequestParameters, TokenSet,
mod factory;

pub use self::factory::OAuthServerFactory;
use crate::{
constants::FALLBACK_ALG,
http_client::dpop::DpopClient,
jose::jwt::{RegisteredClaims, RegisteredClaimsAud},
keyset::Keyset,
resolver::OAuthResolver,
types::{
OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse,
PushedAuthorizationRequestParameters, RefreshRequestParameters,
RevocationRequestParameters, TokenGrantType, TokenRequestParameters, TokenSet,
},
utils::{compare_algos, generate_nonce},
};
use atrium_api::{
did_doc::DidDocument,
types::string::{Datetime, Did, Handle},
};
use atrium_common::resolver::Resolver;
use atrium_xrpc::{
http::{Method, Request, StatusCode},
HttpClient,
};
use crate::utils::{compare_algos, generate_nonce};
use atrium_api::types::string::Datetime;
use atrium_identity::{did::DidResolver, handle::HandleResolver};
use atrium_xrpc::http::{Method, Request, StatusCode};
use atrium_xrpc::HttpClient;
use chrono::{TimeDelta, Utc};
use jose_jwk::Key;
use serde::Serialize;
@@ -32,8 +42,14 @@ pub enum Error {
Token(String),
#[error("unsupported authentication method")]
UnsupportedAuthMethod,
#[error("no refresh token available")]
TokenRefresh,
#[error("failed to parse DID: {0}")]
InvalidDid(&'static str),
#[error(transparent)]
DpopClient(#[from] crate::http_client::dpop::Error),
// #[error(transparent)]
// OAuthSession(#[from] crate::oauth_session::Error),
#[error(transparent)]
Http(#[from] atrium_xrpc::http::Error),
#[error("http client error: {0}")]
@@ -58,7 +74,7 @@ pub type Result<T> = core::result::Result<T, Error>;
pub enum OAuthRequest {
Token(TokenRequestParameters),
Refresh(RefreshRequestParameters),
Revocation,
Revocation(RevocationRequestParameters),
Introspection,
PushedAuthorizationRequest(PushedAuthorizationRequestParameters),
}
@@ -68,7 +84,7 @@ impl OAuthRequest {
String::from(match self {
Self::Token(_) => "token",
Self::Refresh(_) => "refresh",
Self::Revocation => "revocation",
Self::Revocation(_) => "revocation",
Self::Introspection => "introspection",
Self::PushedAuthorizationRequest(_) => "pushed_authorization_request",
})
@@ -77,6 +93,8 @@ impl OAuthRequest {
match self {
Self::Token(_) | Self::Refresh(_) => StatusCode::OK,
Self::PushedAuthorizationRequest(_) => StatusCode::CREATED,
// Unlike https://datatracker.ietf.org/doc/html/rfc7009#section-2.2, oauth-provider seems to return `204`.
Self::Revocation(_) => StatusCode::NO_CONTENT,
_ => unimplemented!(),
}
}
@@ -100,7 +118,7 @@ pub struct OAuthServerAgent<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
{
server_metadata: OAuthAuthorizationServerMetadata,
pub(crate) server_metadata: OAuthAuthorizationServerMetadata,
client_metadata: OAuthClientMetadata,
dpop_client: DpopClient<T>,
resolver: Arc<OAuthResolver<T, D, H>>,
@@ -110,8 +128,8 @@ where
impl<T, D, H> OAuthServerAgent<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
{
pub fn new(
dpop_key: Key,
@@ -129,51 +147,105 @@ where
)?;
Ok(Self { server_metadata, client_metadata, dpop_client, resolver, keyset })
}
/**
* VERY IMPORTANT ! Always call this to process token responses.
*
* Whenever an OAuth token response is received, we **MUST** verify that the
* "sub" is a DID, whose issuer authority is indeed the server we just
* obtained credentials from. This check is a critical step to actually be
* able to use the "sub" (DID) as being the actual user's identifier.
*/
async fn verify_token_response(&self, token_response: OAuthTokenResponse) -> Result<TokenSet> {
// ATPROTO requires that the "sub" is always present in the token response.
let Some(sub) = &token_response.sub else {
pub async fn revoke(&self, token: &str) -> Result<()> {
self.request::<()>(OAuthRequest::Revocation(RevocationRequestParameters {
token: token.into(),
}))
.await?;
Ok(())
}
pub async fn exchange_code(&self, code: &str, verifier: &str) -> Result<TokenSet> {
let token_response = self
.request::<OAuthTokenResponse>(OAuthRequest::Token(TokenRequestParameters {
grant_type: TokenGrantType::AuthorizationCode,
code: code.into(),
redirect_uri: self.client_metadata.redirect_uris[0].clone(), // ?
code_verifier: verifier.into(),
}))
.await?;
let Some(sub) = token_response.sub else {
return Err(Error::Token("missing `sub` in token response".into()));
};
let (metadata, identity) = self.resolver.resolve_from_identity(sub).await?;
if metadata.issuer != self.server_metadata.issuer {
return Err(Error::Token("issuer mismatch".into()));
}
let sub = sub.parse().map_err(Error::InvalidDid)?;
// /!\ IMPORTANT /!\
//
// The token_response MUST always be valid before the "sub" it contains
// can be trusted (see Atproto's OAuth spec for details).
let aud = self.verify_issuer(&sub).await?;

let expires_at = token_response.expires_in.and_then(|expires_in| {
Datetime::now()
.as_ref()
.checked_add_signed(TimeDelta::seconds(expires_in))
.map(Datetime::new)
});
Ok(TokenSet {
sub: sub.clone(),
aud: identity.pds,
iss: metadata.issuer,
iss: self.server_metadata.issuer.clone(),
sub,
aud,
scope: token_response.scope,
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
token_type: token_response.token_type,
expires_at,
})
}
pub async fn exchange_code(&self, code: &str, verifier: &str) -> Result<TokenSet> {
self.verify_token_response(
self.request(OAuthRequest::Token(TokenRequestParameters {
grant_type: TokenGrantType::AuthorizationCode,
code: code.into(),
redirect_uri: self.client_metadata.redirect_uris[0].clone(), // ?
code_verifier: verifier.into(),
pub async fn refresh(&self, token_set: &TokenSet) -> Result<TokenSet> {
let Some(refresh_token) = token_set.refresh_token.as_ref() else {
return Err(Error::TokenRefresh);
};

// /!\ IMPORTANT /!\
//
// The "sub" MUST be a DID, whose issuer authority is indeed the server we
// are trying to obtain credentials from. Note that we are doing this
// *before* we actually try to refresh the token:
// 1) To avoid unnecessary refresh
// 2) So that the refresh is the last async operation, ensuring as few
// async operations happen before the result gets a chance to be stored.
let aud = self.verify_issuer(&token_set.sub).await?;

let response = self
.request::<OAuthTokenResponse>(OAuthRequest::Refresh(RefreshRequestParameters {
grant_type: TokenGrantType::RefreshToken,
refresh_token: refresh_token.clone(),
scope: None,
}))
.await?,
)
.await
.await?;

let expires_at = response.expires_in.and_then(|expires_in| {
Datetime::now()
.as_ref()
.checked_add_signed(TimeDelta::seconds(expires_in))
.map(Datetime::new)
});
Ok(TokenSet {
iss: self.server_metadata.issuer.clone(),
sub: token_set.sub.clone(),
aud,
scope: response.scope,
access_token: response.access_token,
refresh_token: response.refresh_token,
token_type: response.token_type,
expires_at,
})
}
/**
* VERY IMPORTANT ! Always call this to process token responses.
*
* Whenever an OAuth token response is received, we **MUST** verify that the
* "sub" is a DID, whose issuer authority is indeed the server we just
* obtained credentials from. This check is a critical step to actually be
* able to use the "sub" (DID) as being the actual user's identifier.
*
* @returns The user's PDS URL (the resource server for the user)
*/
async fn verify_issuer(&self, sub: &Did) -> Result<String> {
let (metadata, identity) = self.resolver.resolve_from_identity(sub).await?;
if metadata.issuer != self.server_metadata.issuer {
return Err(Error::Token("issuer mismatch".into()));
}
Ok(identity.pds)
}
pub async fn request<O>(&self, request: OAuthRequest) -> Result<O>
where
@@ -185,6 +257,7 @@ where
let body = match &request {
OAuthRequest::Token(params) => self.build_body(params)?,
OAuthRequest::Refresh(params) => self.build_body(params)?,
OAuthRequest::Revocation(params) => self.build_body(params)?,
OAuthRequest::PushedAuthorizationRequest(params) => self.build_body(params)?,
_ => unimplemented!(),
};
@@ -195,7 +268,13 @@ where
.body(body.into_bytes())?;
let res = self.dpop_client.send_http(req).await.map_err(Error::HttpClient)?;
if res.status() == request.expected_status() {
Ok(serde_json::from_slice(res.body())?)
let body = res.body();
if body.is_empty() {
// since an empty body cannot be deserialized, use “null” temporarily to allow deserialization to `()`.
Ok(serde_json::from_slice(b"null")?)
} else {
Ok(serde_json::from_slice(body)?)
}
} else if res.status().is_client_error() {
Err(Error::HttpStatusWithBody(res.status(), serde_json::from_slice(res.body())?))
} else {
@@ -221,7 +300,7 @@ where
Some("private_key_jwt")
if method_supported
.as_ref()
.map_or(false, |v| v.contains(&String::from("private_key_jwt"))) =>
.is_some_and(|v| v.contains(&String::from("private_key_jwt"))) =>
{
if let Some(keyset) = &self.keyset {
let mut algs = self
@@ -258,9 +337,7 @@ where
}
}
Some("none")
if method_supported
.as_ref()
.map_or(false, |v| v.contains(&String::from("none"))) =>
if method_supported.as_ref().is_some_and(|v| v.contains(&String::from("none"))) =>
{
return Ok((None, None))
}
@@ -273,11 +350,26 @@ where
OAuthRequest::Token(_) | OAuthRequest::Refresh(_) => {
Some(&self.server_metadata.token_endpoint)
}
OAuthRequest::Revocation => self.server_metadata.revocation_endpoint.as_ref(),
OAuthRequest::Revocation(_) => self.server_metadata.revocation_endpoint.as_ref(),
OAuthRequest::Introspection => self.server_metadata.introspection_endpoint.as_ref(),
OAuthRequest::PushedAuthorizationRequest(_) => {
self.server_metadata.pushed_authorization_request_endpoint.as_ref()
}
}
}
}

impl<T, D, H> Clone for OAuthServerAgent<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
server_metadata: self.server_metadata.clone(),
client_metadata: self.client_metadata.clone(),
dpop_client: self.dpop_client.clone(),
resolver: Arc::clone(&self.resolver),
keyset: self.keyset.clone(),
}
}
}
69 changes: 69 additions & 0 deletions atrium-oauth/oauth-client/src/server_agent/factory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use super::{OAuthServerAgent, Result};
use crate::{
keyset::Keyset,
resolver::OAuthResolver,
types::{OAuthAuthorizationServerMetadata, OAuthClientMetadata},
};
use atrium_api::{
did_doc::DidDocument,
types::string::{Did, Handle},
};
use atrium_common::resolver::Resolver;
use atrium_identity::Error;
use atrium_xrpc::HttpClient;
use jose_jwk::Key;
use std::sync::Arc;

pub struct OAuthServerFactory<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
{
client_metadata: OAuthClientMetadata,
resolver: Arc<OAuthResolver<T, D, H>>,
http_client: Arc<T>,
keyset: Option<Keyset>,
}

impl<T, D, H> OAuthServerFactory<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
{
pub fn new(
client_metadata: OAuthClientMetadata,
resolver: Arc<OAuthResolver<T, D, H>>,
http_client: Arc<T>,
keyset: Option<Keyset>,
) -> Self {
OAuthServerFactory { client_metadata, resolver, http_client, keyset }
}
}

impl<T, D, H> OAuthServerFactory<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = Error> + Send + Sync,
{
pub async fn build_from_issuer(
&self,
dpop_key: Key,
issuer: impl AsRef<str>,
) -> Result<OAuthServerAgent<T, D, H>> {
let server_metadata = self.resolver.get_authorization_server_metadata(&issuer).await?;
self.build_from_metadata(dpop_key, server_metadata)
}
pub fn build_from_metadata(
&self,
dpop_key: Key,
server_metadata: OAuthAuthorizationServerMetadata,
) -> Result<OAuthServerAgent<T, D, H>> {
OAuthServerAgent::new(
dpop_key,
server_metadata,
self.client_metadata.clone(),
Arc::clone(&self.resolver),
Arc::clone(&self.http_client),
self.keyset.clone(),
)
}
}
21 changes: 2 additions & 19 deletions atrium-oauth/oauth-client/src/store.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,3 @@
pub mod memory;
pub mod session;
pub mod session_registry;
pub mod state;

use std::error::Error;
use std::future::Future;
use std::hash::Hash;

#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait SimpleStore<K, V>
where
K: Eq + Hash,
V: Clone,
{
type Error: Error + Send + Sync + 'static;

fn get(&self, key: &K) -> impl Future<Output = Result<Option<V>, Self::Error>>;
fn set(&self, key: K, value: V) -> impl Future<Output = Result<(), Self::Error>>;
fn del(&self, key: &K) -> impl Future<Output = Result<(), Self::Error>>;
fn clear(&self) -> impl Future<Output = Result<(), Self::Error>>;
}
45 changes: 0 additions & 45 deletions atrium-oauth/oauth-client/src/store/memory.rs

This file was deleted.

17 changes: 17 additions & 0 deletions atrium-oauth/oauth-client/src/store/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use crate::types::TokenSet;
use atrium_api::types::string::Did;
use atrium_common::store::{memory::MemoryStore, Store};
use jose_jwk::Key;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
pub dpop_key: Key,
pub token_set: TokenSet,
}

pub trait SessionStore: Store<Did, Session> {}

pub type MemorySessionStore = MemoryStore<Did, Session>;

impl SessionStore for MemorySessionStore {}
299 changes: 299 additions & 0 deletions atrium-oauth/oauth-client/src/store/session_registry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
use crate::{
server_agent::OAuthServerFactory,
store::session::{Session, SessionStore},
};
use atrium_api::{
did_doc::DidDocument,
types::string::{Datetime, Did, Handle},
};
use atrium_common::resolver::Resolver;
use atrium_xrpc::HttpClient;
use dashmap::DashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::Mutex;

#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
ServerAgent(#[from] crate::server_agent::Error),
#[error("session store error: {0}")]
Store(String),
#[error("session does not exist")]
SessionNotFound,
}

pub struct SessionRegistry<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
{
store: Arc<S>,
server_factory: Arc<OAuthServerFactory<T, D, H>>,
pending: DashMap<Did, Arc<Mutex<()>>>,
}

impl<S, T, D, H> SessionRegistry<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
{
pub fn new(store: S, server_factory: Arc<OAuthServerFactory<T, D, H>>) -> Self {
let store = Arc::new(store);
Self { store: Arc::clone(&store), server_factory, pending: DashMap::new() }
}
}

impl<S, T, D, H> SessionRegistry<S, T, D, H>
where
S: SessionStore + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
D: Resolver<Input = Did, Output = DidDocument, Error = atrium_identity::Error> + Send + Sync,
H: Resolver<Input = Handle, Output = Did, Error = atrium_identity::Error> + Send + Sync,
{
async fn get_refreshed(&self, key: &Did) -> Result<Session, Error> {
let lock =
self.pending.entry(key.clone()).or_insert_with(|| Arc::new(Mutex::new(()))).clone();
let _guard = lock.lock().await;

let mut session = self
.store
.get(key)
.await
.map_err(|e| Error::Store(e.to_string()))?
.ok_or(Error::SessionNotFound)?;
if let Some(expires_at) = &session.token_set.expires_at {
if expires_at > &Datetime::now() {
return Ok(session);
}
}

let server = self
.server_factory
.build_from_issuer(session.dpop_key.clone(), &session.token_set.iss)
.await?;
session.token_set = server.refresh(&session.token_set).await?;
self.store
.set(key.clone(), session.clone())
.await
.map_err(|e| Error::Store(e.to_string()))?;
Ok(session)
}
pub async fn get(&self, key: &Did, refresh: bool) -> Result<Session, Error> {
if refresh {
self.get_refreshed(key).await
} else {
// TODO: cached?
self.store
.get(key)
.await
.map_err(|e| Error::Store(e.to_string()))?
.ok_or(Error::SessionNotFound)
}
}
pub async fn set(&self, key: Did, value: Session) -> Result<(), S::Error> {
self.store.set(key.clone(), value.clone()).await
}
pub async fn del(&self, key: &Did) -> Result<(), S::Error> {
self.store.del(key).await
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
tests::{
client_metadata, dpop_key, oauth_resolver, protected_resource_metadata,
server_metadata, MockDidResolver, NoopHandleResolver,
},
types::{OAuthTokenResponse, OAuthTokenType, RefreshRequestParameters, TokenSet},
};
use atrium_common::store::Store;
use atrium_xrpc::http::{header::CONTENT_TYPE, Request, Response, StatusCode};
use std::{collections::HashMap, time::Duration};
use tokio::{sync::Mutex, time::sleep};

#[derive(Error, Debug)]
enum MockStoreError {}

struct MockHttpClient {
next_token: Arc<Mutex<Option<OAuthTokenResponse>>>,
}

impl Default for MockHttpClient {
fn default() -> Self {
Self {
next_token: Arc::new(Mutex::new(Some(OAuthTokenResponse {
access_token: String::from("new_accesstoken"),
token_type: OAuthTokenType::DPoP,
expires_in: Some(10),
refresh_token: Some(String::from("new_refreshtoken")),
scope: None,
sub: None,
}))),
}
}
}

impl HttpClient for MockHttpClient {
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
println!("{:?}", request);

Ok(match (request.uri().host(), request.uri().path()) {
(Some("iss.example.com"), "/.well-known/oauth-authorization-server") => {
Response::builder()
.header(CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&server_metadata())?)?
}
(Some("aud.example.com"), "/.well-known/oauth-protected-resource") => {
Response::builder()
.header(CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&protected_resource_metadata())?)?
}
(Some("iss.example.com"), "/token") => {
let parameters =
serde_html_form::from_bytes::<RefreshRequestParameters>(request.body())?;
if let Some(token_response) = if parameters.refresh_token == "refreshtoken" {
self.next_token.lock().await.take()
} else {
None
} {
Response::builder()
.header(CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&token_response)?)?
} else {
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", "DPoP error=\"invalid_token\"")
.body(Vec::new())?
}
}
_ => {
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Vec::new())?
}
})
}
}

struct MockSessionStore {
hm: Mutex<HashMap<Did, Session>>,
}

impl Store<Did, Session> for MockSessionStore {
type Error = MockStoreError;

async fn get(&self, key: &Did) -> Result<Option<Session>, Self::Error> {
sleep(Duration::from_micros(10)).await;
Ok(self.hm.lock().await.get(key).cloned())
}
async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> {
sleep(Duration::from_micros(10)).await;
self.hm.lock().await.insert(key, value);
Ok(())
}
async fn del(&self, key: &Did) -> Result<(), Self::Error> {
sleep(Duration::from_micros(10)).await;
self.hm.lock().await.remove(key);
Ok(())
}
async fn clear(&self) -> Result<(), Self::Error> {
unimplemented!()
}
}

impl SessionStore for MockSessionStore {}

impl Default for MockSessionStore {
fn default() -> Self {
Self { hm: Mutex::new(HashMap::from_iter([(did(), session())])) }
}
}

fn did() -> Did {
"did:fake:handle.test".parse().expect("invalid did")
}

fn session() -> Session {
let dpop_key = dpop_key();
let token_set = TokenSet {
iss: String::from("https://iss.example.com"),
sub: did(),
aud: String::from("https://aud.example.com"),
scope: None,
refresh_token: Some(String::from("refreshtoken")),
access_token: String::from("accesstoken"),
token_type: OAuthTokenType::DPoP,
expires_at: None,
};
Session { dpop_key, token_set }
}

fn session_registry(
store: MockSessionStore,
) -> SessionRegistry<MockSessionStore, MockHttpClient, MockDidResolver, NoopHandleResolver>
{
let http_client = Arc::new(MockHttpClient::default());
SessionRegistry::new(
store,
Arc::new(OAuthServerFactory::new(
client_metadata(),
Arc::new(oauth_resolver(Arc::clone(&http_client))),
http_client,
None,
)),
)
}

#[tokio::test]
async fn test_get_session() -> Result<(), Box<dyn std::error::Error>> {
let registry = session_registry(MockSessionStore::default());
let result = registry.get(&"did:fake:nonexistent".parse()?, false).await;
assert!(matches!(result, Err(Error::SessionNotFound)));
let result = registry.get(&"did:fake:handle.test".parse()?, false).await;
let session = result.expect("handle should exist");
assert_eq!(session.token_set.access_token, "accesstoken");
Ok(())
}

#[tokio::test]
async fn test_get_refreshed() -> Result<(), Box<dyn std::error::Error>> {
let registry = session_registry(MockSessionStore::default());
let session = registry.get(&did(), true).await?;
assert_eq!(session.token_set.access_token, "new_accesstoken");
assert_eq!(session.token_set.refresh_token.as_deref(), Some("new_refreshtoken"));
// second time should return the same session
let session = registry.get(&did(), true).await?;
assert_eq!(session.token_set.access_token, "new_accesstoken");
assert_eq!(session.token_set.refresh_token.as_deref(), Some("new_refreshtoken"));
Ok(())
}

#[tokio::test]
async fn test_get_refreshed_parallel() -> Result<(), Box<dyn std::error::Error>> {
let registry = Arc::new(session_registry(MockSessionStore::default()));
let mut handles = Vec::new();
for _ in 0..3 {
let registry = Arc::clone(&registry);
handles.push(tokio::spawn(async move { registry.get(&did(), true).await }));
}
for result in futures::future::join_all(handles).await {
match result? {
Ok(session) => {
assert_eq!(session.token_set.access_token, "new_accesstoken");
assert_eq!(
session.token_set.refresh_token.as_deref(),
Some("new_refreshtoken")
);
}
Err(err) => {
panic!("unexpected error: {err:?}");
}
}
}
Ok(())
}
}
8 changes: 4 additions & 4 deletions atrium-oauth/oauth-client/src/store/state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::memory::MemorySimpleStore;
use super::SimpleStore;
use atrium_common::store::{memory::MemoryStore, Store};
use jose_jwk::Key;
use serde::{Deserialize, Serialize};

@@ -8,10 +7,11 @@ pub struct InternalStateData {
pub iss: String,
pub dpop_key: Key,
pub verifier: String,
pub app_state: Option<String>,
}

pub trait StateStore: SimpleStore<String, InternalStateData> {}
pub trait StateStore: Store<String, InternalStateData> {}

pub type MemoryStateStore = MemorySimpleStore<String, InternalStateData>;
pub type MemoryStateStore = MemoryStore<String, InternalStateData>;

impl StateStore for MemoryStateStore {}
14 changes: 5 additions & 9 deletions atrium-oauth/oauth-client/src/types.rs
Original file line number Diff line number Diff line change
@@ -4,17 +4,13 @@ mod request;
mod response;
mod token;

pub use self::client_metadata::*;
pub use self::metadata::*;
pub use self::request::*;
pub use self::response::*;
pub use self::token::*;
use crate::atproto::{KnownScope, Scope};
pub use client_metadata::{OAuthClientMetadata, TryIntoOAuthClientMetadata};
pub use metadata::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata};
pub use request::{
AuthorizationCodeChallengeMethod, AuthorizationResponseType,
PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType,
TokenRequestParameters,
};
pub use response::{OAuthPusehedAuthorizationRequestResponse, OAuthTokenResponse};
use serde::Deserialize;
pub use token::TokenSet;

#[derive(Debug, Deserialize)]
pub enum AuthorizeOptionPrompt {
2 changes: 1 addition & 1 deletion atrium-oauth/oauth-client/src/types/client_metadata.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ use crate::keyset::Keyset;
use jose_jwk::JwkSet;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
pub struct OAuthClientMetadata {
pub client_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
4 changes: 2 additions & 2 deletions atrium-oauth/oauth-client/src/types/metadata.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use atrium_api::types::string::Language;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
pub struct OAuthAuthorizationServerMetadata {
// https://datatracker.ietf.org/doc/html/rfc8414#section-2
pub issuer: String,
@@ -50,7 +50,7 @@ pub struct OAuthAuthorizationServerMetadata {

// https://datatracker.ietf.org/doc/draft-ietf-oauth-resource-metadata/
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-08#section-2
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
pub struct OAuthProtectedResourceMetadata {
pub resource: String,
pub authorization_servers: Option<Vec<String>>,
28 changes: 16 additions & 12 deletions atrium-oauth/oauth-client/src/types/request.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use serde::Serialize;
use serde::{Deserialize, Serialize};

#[allow(dead_code)]
#[derive(Serialize)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthorizationResponseType {
Code,
@@ -10,8 +9,7 @@ pub enum AuthorizationResponseType {
IdToken,
}

#[allow(dead_code)]
#[derive(Serialize)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthorizationResponseMode {
Query,
@@ -20,15 +18,14 @@ pub enum AuthorizationResponseMode {
FormPost,
}

#[allow(dead_code)]
#[derive(Serialize)]
#[derive(Serialize, Deserialize)]
pub enum AuthorizationCodeChallengeMethod {
S256,
#[serde(rename = "plain")]
Plain,
}

#[derive(Serialize)]
#[derive(Serialize, Deserialize)]
pub struct PushedAuthorizationRequestParameters {
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
pub response_type: AuthorizationResponseType,
@@ -45,15 +42,14 @@ pub struct PushedAuthorizationRequestParameters {
pub prompt: Option<String>,
}

#[derive(Serialize)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenGrantType {
AuthorizationCode,
#[allow(dead_code)]
RefreshToken,
}

#[derive(Serialize)]
#[derive(Serialize, Deserialize)]
pub struct TokenRequestParameters {
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
pub grant_type: TokenGrantType,
@@ -63,10 +59,18 @@ pub struct TokenRequestParameters {
pub code_verifier: String,
}

#[derive(Serialize)]
#[derive(Serialize, Deserialize)]
pub struct RefreshRequestParameters {
// https://datatracker.ietf.org/doc/html/rfc6749#section-6
pub grant_type: TokenGrantType,
pub refresh_token: String,
pub scope: Option<String>,
}

// https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
#[derive(Serialize, Deserialize)]
pub struct RevocationRequestParameters {
pub token: String,
// ?
// pub token_type_hint: Option<String>,
}
4 changes: 2 additions & 2 deletions atrium-oauth/oauth-client/src/types/token.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use super::response::OAuthTokenType;
use atrium_api::types::string::Datetime;
use atrium_api::types::string::{Datetime, Did};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
pub iss: String,
pub sub: String,
pub sub: Did,
pub aud: String,
pub scope: Option<String>,

2 changes: 1 addition & 1 deletion atrium-xrpc/src/traits.rs
Original file line number Diff line number Diff line change
@@ -131,7 +131,7 @@ where
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map_or(false, |content_type| content_type.starts_with("application/json"))
.is_some_and(|content_type| content_type.starts_with("application/json"))
{
Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?))
} else {