Skip to content

Commit e65584d

Browse files
authored
Merge pull request #253 from cognitedata/oauth-support
feat: Add SASL/OAUTHTOKEN support
2 parents 8c20bc9 + 25baeeb commit e65584d

File tree

8 files changed

+130
-26
lines changed

8 files changed

+130
-26
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ tokio = { version = "1.19", default-features = false, features = ["io-util", "ne
3636
tokio-rustls = { version = "0.26", optional = true, default-features = false, features = ["logging", "ring", "tls12"] }
3737
tracing = "0.1"
3838
zstd = { version = "0.13", optional = true }
39-
rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2"]}
39+
rsasl = { version = "2.1", default-features = false, features = ["config_builder", "provider", "plain", "scram-sha-2", "oauthbearer"]}
4040

4141
[dev-dependencies]
4242
assert_matches = "1.5"

src/client/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use error::{Error, Result};
2222

2323
use self::{controller::ControllerClient, partition::UnknownTopicHandling};
2424

25-
pub use crate::connection::{Credentials, SaslConfig};
25+
pub use crate::connection::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};
2626

2727
#[derive(Debug, Error)]
2828
pub enum ProduceError {

src/connection.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ use crate::{
2020
client::metadata_cache::MetadataCache,
2121
};
2222

23-
pub use self::transport::Credentials;
24-
pub use self::transport::SaslConfig;
2523
pub use self::transport::TlsConfig;
24+
pub use self::transport::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};
2625

2726
mod topology;
2827
mod transport;

src/connection/transport.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use tokio::net::TcpStream;
1111
use tokio_rustls::{client::TlsStream, TlsConnector};
1212

1313
mod sasl;
14-
pub use sasl::{Credentials, SaslConfig};
14+
pub use sasl::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};
1515

1616
#[cfg(feature = "transport-tls")]
1717
pub type TlsConfig = Option<Arc<rustls::ClientConfig>>;

src/connection/transport/sasl.rs

Lines changed: 108 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
use std::{fmt::Debug, sync::Arc};
2+
3+
use futures::future::BoxFuture;
4+
use rsasl::{
5+
callback::SessionCallback,
6+
config::SASLConfig,
7+
property::{AuthzId, OAuthBearerKV, OAuthBearerToken},
8+
};
9+
10+
use crate::messenger::SaslError;
11+
112
#[derive(Debug, Clone)]
213
pub enum SaslConfig {
314
/// SASL - PLAIN
@@ -15,6 +26,11 @@ pub enum SaslConfig {
1526
/// # References
1627
/// - <https://datatracker.ietf.org/doc/html/draft-melnikov-scram-sha-512-04>
1728
ScramSha512(Credentials),
29+
/// SASL - OAUTHBEARER
30+
///
31+
/// # References
32+
/// - <https://datatracker.ietf.org/doc/html/rfc7628>
33+
Oauthbearer(OauthBearerCredentials),
1834
}
1935

2036
#[derive(Debug, Clone)]
@@ -30,19 +46,104 @@ impl Credentials {
3046
}
3147

3248
impl SaslConfig {
33-
pub(crate) fn credentials(&self) -> Credentials {
49+
pub(crate) async fn get_sasl_config(&self) -> Result<Arc<SASLConfig>, SaslError> {
3450
match self {
35-
Self::Plain(credentials) => credentials.clone(),
36-
Self::ScramSha256(credentials) => credentials.clone(),
37-
Self::ScramSha512(credentials) => credentials.clone(),
51+
Self::Plain(credentials)
52+
| Self::ScramSha256(credentials)
53+
| Self::ScramSha512(credentials) => Ok(SASLConfig::with_credentials(
54+
None,
55+
credentials.username.clone(),
56+
credentials.password.clone(),
57+
)?),
58+
Self::Oauthbearer(credentials) => {
59+
// Fetch the token first, since that's an async call.
60+
let token = (*credentials.callback)()
61+
.await
62+
.map_err(SaslError::Callback)?;
63+
64+
struct OauthProvider {
65+
authz_id: Option<String>,
66+
bearer_kvs: Vec<(String, String)>,
67+
token: String,
68+
}
69+
70+
// Define a callback that is called while stepping through the SASL client
71+
// to provide necessary data for oauth.
72+
// Since this callback is synchronous, we fetch the token first. Generally
73+
// speaking the SASL process should not take long enough for the token to
74+
// expire, but we do need to check for token expiry each time we authenticate.
75+
impl SessionCallback for OauthProvider {
76+
fn callback(
77+
&self,
78+
_session_data: &rsasl::callback::SessionData,
79+
_context: &rsasl::callback::Context<'_>,
80+
request: &mut rsasl::callback::Request<'_>,
81+
) -> Result<(), rsasl::prelude::SessionError> {
82+
request
83+
.satisfy::<OAuthBearerKV>(
84+
&self
85+
.bearer_kvs
86+
.iter()
87+
.map(|(k, v)| (k.as_str(), v.as_str()))
88+
.collect::<Vec<_>>(),
89+
)?
90+
.satisfy::<OAuthBearerToken>(&self.token)?;
91+
if let Some(authz_id) = &self.authz_id {
92+
request.satisfy::<AuthzId>(authz_id)?;
93+
}
94+
Ok(())
95+
}
96+
}
97+
98+
Ok(SASLConfig::builder()
99+
.with_default_mechanisms()
100+
.with_callback(OauthProvider {
101+
authz_id: credentials.authz_id.clone(),
102+
bearer_kvs: credentials.bearer_kvs.clone(),
103+
token,
104+
})?)
105+
}
38106
}
39107
}
40108

41109
pub(crate) fn mechanism(&self) -> &str {
110+
use rsasl::mechanisms::*;
42111
match self {
43-
Self::Plain { .. } => "PLAIN",
44-
Self::ScramSha256 { .. } => "SCRAM-SHA-256",
45-
Self::ScramSha512 { .. } => "SCRAM-SHA-512",
112+
Self::Plain { .. } => plain::PLAIN.mechanism.as_str(),
113+
Self::ScramSha256 { .. } => scram::SCRAM_SHA256.mechanism.as_str(),
114+
Self::ScramSha512 { .. } => scram::SCRAM_SHA512.mechanism.as_str(),
115+
Self::Oauthbearer { .. } => oauthbearer::OAUTHBEARER.mechanism.as_str(),
46116
}
47117
}
48118
}
119+
120+
type DynError = Box<dyn std::error::Error + Send + Sync>;
121+
122+
/// Callback for fetching an OAUTH token. This should cache tokens and only request a new token
123+
/// when the old is close to expiring.
124+
pub type OauthCallback =
125+
Arc<dyn Fn() -> BoxFuture<'static, Result<String, DynError>> + Send + Sync>;
126+
127+
#[derive(Clone)]
128+
pub struct OauthBearerCredentials {
129+
/// Callback that should return a token that is valid and will remain valid for
130+
/// long enough to complete authentication. This should cache the token and only request
131+
/// a new one when the old is close to expiring.
132+
/// The token must be on [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) format.
133+
pub callback: OauthCallback,
134+
/// ID of a user to impersonate. Can be left as `None` to authenticate using
135+
/// the user for the token returned by `callback`.
136+
pub authz_id: Option<String>,
137+
/// Custom key-value pairs sent as part of the SASL request. Most normal usage
138+
/// can let this be an empty list.
139+
pub bearer_kvs: Vec<(String, String)>,
140+
}
141+
142+
impl Debug for OauthBearerCredentials {
143+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144+
f.debug_struct("OauthBearerCredentials")
145+
.field("authz_id", &self.authz_id)
146+
.field("bearer_kvs", &self.bearer_kvs)
147+
.finish_non_exhaustive()
148+
}
149+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pub mod client;
2828
mod connection;
2929

3030
pub use connection::Error as ConnectionError;
31+
pub use messenger::SaslError;
3132

3233
#[cfg(feature = "unstable-fuzzing")]
3334
pub mod messenger;

src/messenger.rs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ use std::{
1313
use futures::future::BoxFuture;
1414
use parking_lot::Mutex;
1515
use rsasl::{
16-
config::SASLConfig,
1716
mechname::MechanismNameError,
18-
prelude::{Mechname, SessionError},
17+
prelude::{Mechname, SASLError, SessionError},
1918
};
2019
use thiserror::Error;
2120
use tokio::{
@@ -28,6 +27,7 @@ use tokio::{
2827
};
2928
use tracing::{debug, info, warn};
3029

30+
use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
3131
use crate::{
3232
backoff::ErrorOrThrottle,
3333
protocol::{
@@ -48,10 +48,6 @@ use crate::{
4848
client::SaslConfig,
4949
protocol::{api_version::ApiVersionRange, primitives::CompactString},
5050
};
51-
use crate::{
52-
connection::Credentials,
53-
protocol::{messages::ApiVersionsRequest, traits::ReadType},
54-
};
5551

5652
#[derive(Debug)]
5753
struct Response {
@@ -205,6 +201,12 @@ pub enum SaslError {
205201
#[error("Sasl session error: {0}")]
206202
SaslSessionError(#[from] SessionError),
207203

204+
#[error("Invalid SASL config: {0}")]
205+
InvalidConfig(#[from] SASLError),
206+
207+
#[error("Error in user defined callback: {0}")]
208+
Callback(Box<dyn std::error::Error + Send + Sync>),
209+
208210
#[error("unsupported sasl mechanism")]
209211
UnsupportedSaslMechanism,
210212
}
@@ -581,8 +583,7 @@ where
581583
let mechanism = config.mechanism();
582584
let resp = self.sasl_handshake(mechanism).await?;
583585

584-
let Credentials { username, password } = config.credentials();
585-
let config = SASLConfig::with_credentials(None, username, password).unwrap();
586+
let config = config.get_sasl_config().await?;
586587
let sasl = rsasl::prelude::SASLClient::new(config);
587588
let raw_mechanisms = resp.mechanisms.0.unwrap_or_default();
588589
let mechanisms = raw_mechanisms
@@ -604,12 +605,14 @@ where
604605
loop {
605606
let mut to_sent = Cursor::new(Vec::new());
606607
let state = session.step(data_received.as_deref(), &mut to_sent)?;
607-
if !state.is_running() {
608+
609+
if state.has_sent_message() {
610+
let authentication_response =
611+
self.sasl_authentication(to_sent.into_inner()).await?;
612+
data_received = Some(authentication_response.auth_bytes.0);
613+
} else {
608614
break;
609615
}
610-
611-
let authentication_response = self.sasl_authentication(to_sent.into_inner()).await?;
612-
data_received = Some(authentication_response.auth_bytes.0);
613616
}
614617

615618
Ok(())

src/protocol/frame.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ mod tests {
163163

164164
data.set_position(0);
165165
let actual = data.read_message(0).await.unwrap();
166-
assert_eq!(actual, vec![]);
166+
assert!(actual.is_empty())
167167
}
168168

169169
#[tokio::test]
@@ -172,6 +172,6 @@ mod tests {
172172
client.write_message(&[]).await.unwrap();
173173

174174
let actual = server.read_message(0).await.unwrap();
175-
assert_eq!(actual, vec![]);
175+
assert!(actual.is_empty())
176176
}
177177
}

0 commit comments

Comments
 (0)