diff --git a/Cargo.toml b/Cargo.toml index 0503179..08d9bbc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "proton-api-rs" authors = ["Leander Beernaert "] -version = "0.10.2" +version = "0.11.0" edition = "2021" license = "AGPL-3.0-only" description = "Unofficial implemention of proton REST API in rust" @@ -30,6 +30,7 @@ ureq = {version="2.6", optional=true, features=["socks-proxy", "socks"]} default = [] http-ureq = ["dep:ureq"] http-reqwest = ["dep:reqwest"] +async-traits =[] [dependencies.reqwest] version = "0.11" @@ -40,7 +41,6 @@ optional = true [dev-dependencies] env_logger = "0.10" tokio = {version ="1", features = ["full"]} -httpmock = "0.6" go-gpa-server = {path= "go-gpa-server"} [[example]] @@ -53,5 +53,5 @@ required-features = ["http-ureq"] [[test]] name = "session" -required-features = ["http-ureq"] +required-features = ["http-ureq", "http-reqwest"] diff --git a/examples/user_id.rs b/examples/user_id.rs index bd78dd1..92877bd 100644 --- a/examples/user_id.rs +++ b/examples/user_id.rs @@ -1,4 +1,6 @@ -use proton_api_rs::{http, ping_async}; +use proton_api_rs::domain::SecretString; +use proton_api_rs::http::Sequence; +use proton_api_rs::{http, ping}; use proton_api_rs::{Session, SessionType}; pub use tokio; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; @@ -6,7 +8,7 @@ use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; #[tokio::main(worker_threads = 1)] async fn main() { let user_email = std::env::var("PAPI_USER_EMAIL").unwrap(); - let user_password = std::env::var("PAPI_USER_PASSWORD").unwrap(); + let user_password = SecretString::new(std::env::var("PAPI_USER_PASSWORD").unwrap()); let app_version = std::env::var("PAPI_APP_VERSION").unwrap(); let client = http::ClientBuilder::new() @@ -14,15 +16,16 @@ async fn main() { .build::() .unwrap(); - ping_async(&client).await.unwrap(); + ping().do_async(&client).await.unwrap(); - let session = match Session::login_async(&client, &user_email, &user_password, None, None) + let session = match Session::login(&user_email, &user_password, None) + .do_async(&client) .await .unwrap() { SessionType::Authenticated(c) => c, - SessionType::AwaitingTotp(mut t) => { + SessionType::AwaitingTotp(t) => { let mut stdout = tokio::io::stdout(); let mut line_reader = tokio::io::BufReader::new(tokio::io::stdin()).lines(); let session = { @@ -41,13 +44,12 @@ async fn main() { let totp = line.trim_end_matches('\n'); - match t.submit_totp_async(&client, totp).await { + match t.submit_totp(totp).do_async(&client).await { Ok(ac) => { session = Some(ac); break; } - Err((et, e)) => { - t = et; + Err(e) => { eprintln!("Failed to submit totp: {e}"); continue; } @@ -65,8 +67,8 @@ async fn main() { } }; - let user = session.get_user_async(&client).await.unwrap(); + let user = session.get_user().do_async(&client).await.unwrap(); println!("User ID is {}", user.id); - session.logout_async(&client).await.unwrap(); + session.logout().do_async(&client).await.unwrap(); } diff --git a/examples/user_id_sync.rs b/examples/user_id_sync.rs index 7de976f..c256519 100644 --- a/examples/user_id_sync.rs +++ b/examples/user_id_sync.rs @@ -1,4 +1,6 @@ use proton_api_rs::clientv2::{ping, SessionType}; +use proton_api_rs::domain::SecretString; +use proton_api_rs::http::Sequence; use proton_api_rs::{http, Session}; use std::io::{BufRead, Write}; @@ -6,7 +8,7 @@ fn main() { env_logger::init(); let user_email = std::env::var("PAPI_USER_EMAIL").unwrap(); - let user_password = std::env::var("PAPI_USER_PASSWORD").unwrap(); + let user_password = SecretString::new(std::env::var("PAPI_USER_PASSWORD").unwrap()); let app_version = std::env::var("PAPI_APP_VERSION").unwrap(); let client = http::ClientBuilder::new() @@ -15,12 +17,12 @@ fn main() { .build::() .unwrap(); - ping(&client).unwrap(); + ping().do_sync(&client).unwrap(); - let login_result = Session::login(&client, &user_email, &user_password, None, None); + let login_result = Session::login(&user_email, &user_password, None).do_sync(&client); let session = match login_result.unwrap() { SessionType::Authenticated(s) => s, - SessionType::AwaitingTotp(mut t) => { + SessionType::AwaitingTotp(t) => { let mut line_reader = std::io::BufReader::new(std::io::stdin()); let session = { let mut session = None; @@ -38,13 +40,12 @@ fn main() { let totp = line.trim_end_matches('\n'); - match t.submit_totp(&client, totp) { + match t.submit_totp(totp).do_sync(&client) { Ok(ac) => { session = Some(ac); break; } - Err((et, e)) => { - t = et; + Err(e) => { eprintln!("Failed to submit totp: {e}"); continue; } @@ -62,8 +63,8 @@ fn main() { } }; - let user = session.get_user(&client).unwrap(); + let user = session.get_user().do_sync(&client).unwrap(); println!("User ID is {}", user.id); - session.logout(&client).unwrap(); + session.logout().do_sync(&client).unwrap(); } diff --git a/go-gpa-server/build.rs b/go-gpa-server/build.rs index 7dae73f..2c8e6b2 100644 --- a/go-gpa-server/build.rs +++ b/go-gpa-server/build.rs @@ -27,6 +27,7 @@ fn target_path_for_go_lib() -> (PathBuf, PathBuf) { fn build_go_lib(lib_path: &Path) { let mut command = Command::new("go"); + #[cfg(any(target_os= "linux",target_os = "android"))] command.env("CGO_LDFLAGS", "-Wl,--build-id=none"); command.arg("build"); command.arg("-ldflags=-buildid="); diff --git a/go-gpa-server/go/lib.go b/go-gpa-server/go/lib.go index ac7176c..e4f55e5 100644 --- a/go-gpa-server/go/lib.go +++ b/go-gpa-server/go/lib.go @@ -9,6 +9,7 @@ typedef const char cchar_t; import "C" import ( "sync" + "time" "unsafe" "github.com/ProtonMail/go-proton-api/server" @@ -104,6 +105,18 @@ func gpaCreateUser(h int, cuser *C.cchar_t, cpassword *C.cchar_t, outUserID **C. return 0 } +//export gpaSetAuthLife +func gpaSetAuthLife(h int, seconds int) int { + srv := alloc.resolve(h) + if srv == nil { + return -1 + } + + srv.SetAuthLife(time.Duration(seconds) * time.Second) + + return 0 +} + //export CStrFree func CStrFree(ptr *C.char) { C.free(unsafe.Pointer(ptr)) diff --git a/go-gpa-server/src/lib.rs b/go-gpa-server/src/lib.rs index c4ab7f5..22ded0b 100644 --- a/go-gpa-server/src/lib.rs +++ b/go-gpa-server/src/lib.rs @@ -68,6 +68,16 @@ impl Server { )) } } + + pub fn set_auth_timeout(&self, duration: std::time::Duration) -> Result<()> { + unsafe { + if go::gpaSetAuthLife(self.0, duration.as_secs() as i64) < 0 { + return Err("Failed to set auth timeout".to_string()); + } + + Ok(()) + } + } } impl Drop for Server { diff --git a/go-srp/build.rs b/go-srp/build.rs index 3b04330..fd1af67 100644 --- a/go-srp/build.rs +++ b/go-srp/build.rs @@ -74,6 +74,7 @@ fn target_path_for_go_lib(platform: Platform) -> (PathBuf, PathBuf) { fn build_go_lib(lib_path: &Path, platform: Platform) { let mut command = Command::new("go"); + #[cfg(any(target_os= "linux",target_os = "android"))] command.env("CGO_LDFLAGS", "-Wl,--build-id=none"); match platform { Platform::Desktop => {} diff --git a/src/clientv2/client.rs b/src/clientv2/client.rs index a72dc9d..883f913 100644 --- a/src/clientv2/client.rs +++ b/src/clientv2/client.rs @@ -1,30 +1,10 @@ -use crate::http; -use crate::http::Request; +use crate::http::{Request, RequestDesc}; use crate::requests::{CaptchaRequest, Ping}; -pub fn ping(client: &T) -> Result<(), http::Error> { - Ping.execute_sync::(client, &http::DefaultRequestFactory {}) +pub fn ping() -> impl Request { + Ping.to_request() } -pub async fn ping_async(client: &T) -> Result<(), http::Error> { - Ping.execute_async::(client, &http::DefaultRequestFactory {}) - .await -} - -pub fn captcha_get( - client: &T, - token: &str, - force_web: bool, -) -> Result { - CaptchaRequest::new(token, force_web).execute_sync(client, &http::DefaultRequestFactory {}) -} - -pub async fn captcha_get_async( - client: &T, - token: &str, - force_web: bool, -) -> Result { - CaptchaRequest::new(token, force_web) - .execute_async(client, &http::DefaultRequestFactory {}) - .await +pub fn captcha_get(token: &str, force_web: bool) -> impl Request { + CaptchaRequest::new(token, force_web).to_request() } diff --git a/src/clientv2/mod.rs b/src/clientv2/mod.rs index 29456b2..1f3e450 100644 --- a/src/clientv2/mod.rs +++ b/src/clientv2/mod.rs @@ -1,9 +1,7 @@ mod client; -mod request_repeater; mod session; mod totp; pub use client::*; -pub use request_repeater::*; pub use session::*; pub use totp::*; diff --git a/src/clientv2/request_repeater.rs b/src/clientv2/request_repeater.rs deleted file mode 100644 index 3a1e6d3..0000000 --- a/src/clientv2/request_repeater.rs +++ /dev/null @@ -1,223 +0,0 @@ -//! Automatic request repeater based on the expectations Proton has for their clients. - -use crate::domain::{SecretString, UserUid}; -use crate::http::{ - ClientAsync, ClientSync, DefaultRequestFactory, Method, Request, RequestData, RequestFactory, -}; -use crate::requests::{AuthRefreshRequest, UserAuth}; -use crate::{http, SessionRefreshData}; -use secrecy::{ExposeSecret, Secret}; - -pub trait OnAuthRefreshed: Send + Sync { - fn on_auth_refreshed(&self, user: &Secret, token: &SecretString); -} - -pub struct RequestRepeater { - user_auth: parking_lot::RwLock, - on_auth_refreshed: Option>, -} - -impl std::fmt::Debug for RequestRepeater { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "RequestRepeater{{user_auth:{:?} on_auth_refreshed:{}}}", - self.user_auth, - if self.on_auth_refreshed.is_some() { - "Some" - } else { - "None" - } - ) - } -} - -impl RequestRepeater { - pub fn new(user_auth: UserAuth, on_auth_refreshed: Option>) -> Self { - Self { - user_auth: parking_lot::RwLock::new(user_auth), - on_auth_refreshed, - } - } - - fn refresh_auth(&self, client: &C) -> http::Result<()> { - let mut borrow = self.user_auth.write(); - match AuthRefreshRequest::new( - borrow.uid.expose_secret(), - borrow.refresh_token.expose_secret(), - ) - .execute_sync(client, &DefaultRequestFactory {}) - { - Ok(s) => { - *borrow = UserAuth::from_auth_refresh_response(&s); - if let Some(cb) = &self.on_auth_refreshed { - cb.on_auth_refreshed(&borrow.uid, &borrow.access_token); - } - Ok(()) - } - Err(e) => Err(e), - } - } - - async fn refresh_auth_async(&self, client: &C) -> http::Result<()> { - // Have to clone here due to async boundaries. - let user_auth = { self.user_auth.read().clone() }; - match AuthRefreshRequest::new( - user_auth.uid.expose_secret(), - user_auth.refresh_token.expose_secret(), - ) - .execute_async(client, &DefaultRequestFactory {}) - .await - { - Ok(s) => { - let mut borrow = self.user_auth.write(); - *borrow = UserAuth::from_auth_refresh_response(&s); - if let Some(cb) = &self.on_auth_refreshed { - cb.on_auth_refreshed(&borrow.uid, &borrow.access_token); - } - Ok(()) - } - Err(e) => Err(e), - } - } - - pub fn execute( - &self, - client: &C, - request: R, - ) -> http::Result { - match request.execute_sync(client, self) { - Ok(r) => Ok(r), - Err(original_error) => { - if let http::Error::API(api_err) = &original_error { - if api_err.http_code == 401 { - log::debug!("Account session expired, attempting refresh"); - // Session expired/not authorized, try auth refresh. - if let Err(e) = self.refresh_auth(client) { - log::error!("Failed to refresh account {e}"); - return Err(original_error); - } - - // Execute request again - return request.execute_sync(client, self); - } - } - Err(original_error) - } - } - } - - pub async fn execute_async<'a, C: ClientAsync, R: Request + 'a>( - &'a self, - client: &'a C, - request: R, - ) -> http::Result { - match request.execute_async(client, self).await { - Ok(r) => Ok(r), - Err(original_error) => { - if let http::Error::API(api_err) = &original_error { - log::debug!("Account session expired, attempting refresh"); - if api_err.http_code == 401 { - // Session expired/not authorized, try auth refresh. - if let Err(e) = self.refresh_auth_async(client).await { - log::error!("Failed to refresh account {e}"); - return Err(original_error); - } - - // Execute request again - return request.execute_async(client, self).await; - } - } - Err(original_error) - } - } - } - - pub fn get_refresh_data(&self) -> SessionRefreshData { - let borrow = self.user_auth.read(); - SessionRefreshData { - user_uid: borrow.uid.clone(), - token: borrow.refresh_token.clone(), - } - } -} - -impl RequestFactory for RequestRepeater { - fn new_request(&self, method: Method, url: &str) -> RequestData { - let accessor = self.user_auth.read(); - RequestData::new(method, url) - .header(http::X_PM_UID_HEADER, &accessor.uid.expose_secret().0) - .bearer_token(accessor.access_token.expose_secret()) - } -} - -#[cfg(test)] -mod test { - - #[test] - #[cfg(feature = "http-ureq")] - fn request_repeats_with_401() { - use crate::domain::{EventId, SecretString, UserUid}; - use crate::http::X_PM_UID_HEADER; - use crate::requests::{GetLatestEventRequest, UserAuth}; - use crate::RequestRepeater; - use httpmock::prelude::*; - use secrecy::Secret; - - let server = MockServer::start(); - let url = server.base_url(); - - let client = crate::http::ClientBuilder::new() - .allow_http() - .base_url(&url) - .build::() - .unwrap(); - - let repeater = RequestRepeater::new( - UserAuth { - uid: Secret::new(UserUid("test-uid".to_string())), - access_token: SecretString::new("secret-token".to_string()), - refresh_token: SecretString::new("refresh-token".to_string()), - }, - None, - ); - - let expected_latest_event_id = EventId("My_Event_Id".to_string()); - - let latest_event_first_call = server.mock(|when, then| { - when.method(GET) - .path("/core/v4/events/latest") - .header(X_PM_UID_HEADER, "test-uid"); - then.status(401); - }); - - let latest_event_second_call = server.mock(|when, then| { - when.method(GET) - .path("/core/v4/events/latest") - .header(X_PM_UID_HEADER, "User_UID"); - then.status(200) - .body(format!(r#"{{"EventID":"{}"}}"#, expected_latest_event_id.0)); - }); - - let refresh_mock = server.mock(|when, then| { - when.method(POST).path("/auth/v4/refresh"); - - let response = r#"{ - "UID": "User_UID", - "TokenType": "type", - "AccessToken": "access-token", - "RefreshToken": "refresh-token", - "Scope": "Scope" -}"#; - - then.status(200).body(response); - }); - - let latest_event = repeater.execute(&client, GetLatestEventRequest {}).unwrap(); - assert_eq!(latest_event.event_id, expected_latest_event_id); - - latest_event_first_call.assert(); - refresh_mock.assert(); - latest_event_second_call.assert(); - } -} diff --git a/src/clientv2/session.rs b/src/clientv2/session.rs index 9cc32d0..bfebbbf 100644 --- a/src/clientv2/session.rs +++ b/src/clientv2/session.rs @@ -1,16 +1,25 @@ -use crate::clientv2::request_repeater::RequestRepeater; use crate::clientv2::TotpSession; use crate::domain::{ - Event, EventId, HumanVerification, HumanVerificationLoginData, TwoFactorAuth, User, UserUid, + EventId, HumanVerification, HumanVerificationLoginData, SecretString, TwoFactorAuth, User, + UserUid, +}; +use crate::http; +use crate::http::{ + ClientAsync, ClientRequest, ClientRequestBuilder, ClientSync, FromResponse, Request, + RequestDesc, Sequence, StateProducerSequence, X_PM_UID_HEADER, }; -use crate::http::{DefaultRequestFactory, Request}; use crate::requests::{ AuthInfoRequest, AuthInfoResponse, AuthRefreshRequest, AuthRequest, AuthResponse, - GetEventRequest, GetLatestEventRequest, LogoutRequest, TFAStatus, UserAuth, UserInfoRequest, + GetEventRequest, GetLatestEventRequest, LogoutRequest, TFAStatus, TOTPRequest, UserAuth, + UserInfoRequest, }; -use crate::{http, OnAuthRefreshed}; use go_srp::SRPAuth; -use secrecy::Secret; +use secrecy::{ExposeSecret, Secret}; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; +use std::sync::Arc; #[derive(Debug, thiserror::Error)] pub enum LoginError { @@ -46,174 +55,80 @@ pub enum SessionType { /// users. #[derive(Debug)] pub struct Session { - pub(super) repeater: RequestRepeater, + pub(super) user_auth: Arc>, } impl Session { - fn new(user: UserAuth, on_auth_refreshed_cb: Option>) -> Self { + fn new(user: UserAuth) -> Self { Self { - repeater: RequestRepeater::new(user, on_auth_refreshed_cb), + user_auth: Arc::new(parking_lot::RwLock::new(user)), } } - pub fn login( - client: &T, - username: &str, - password: &str, + pub fn login<'a>( + username: &'a str, + password: &'a SecretString, human_verification: Option, - on_auth_refreshed: Option>, - ) -> Result { - let auth_info_response = - AuthInfoRequest { username }.execute_sync::(client, &DefaultRequestFactory {})?; - - let proof = generate_session_proof(username, password, &auth_info_response)?; - - let auth_response = AuthRequest { + ) -> impl Sequence<'a, Output = SessionType, Error = LoginError> + 'a { + let state = State { username, - client_ephemeral: &proof.client_ephemeral, - client_proof: &proof.client_proof, - srp_session: auth_info_response.srp_session.as_ref(), - human_verification, - } - .execute_sync::(client, &DefaultRequestFactory {}) - .map_err(map_human_verification_err)?; - - validate_server_proof(&proof, &auth_response, on_auth_refreshed) - } - - pub async fn login_async( - client: &T, - username: &str, - password: &str, - human_verification: Option, - on_auth_refreshed: Option>, - ) -> Result { - let auth_info_response = AuthInfoRequest { username } - .execute_async::(client, &DefaultRequestFactory {}) - .await?; - - let proof = generate_session_proof(username, password, &auth_info_response)?; - - let auth_response = AuthRequest { - username, - client_ephemeral: &proof.client_ephemeral, - client_proof: &proof.client_proof, - srp_session: auth_info_response.srp_session.as_ref(), - human_verification, - } - .execute_async::(client, &DefaultRequestFactory {}) - .await - .map_err(map_human_verification_err)?; - - validate_server_proof(&proof, &auth_response, on_auth_refreshed) - } - - pub async fn refresh_async( - client: &T, - user_uid: &UserUid, - token: &str, - on_auth_refreshed: Option>, - ) -> http::Result { - let refresh_response = AuthRefreshRequest::new(user_uid, token) - .execute_async(client, &DefaultRequestFactory {}) - .await?; - let user = UserAuth::from_auth_refresh_response(&refresh_response); - Ok(Session::new(user, on_auth_refreshed)) - } - - pub fn refresh( - client: &T, - user_uid: &UserUid, - token: &str, - on_auth_refreshed: Option>, - ) -> http::Result { - let refresh_response = AuthRefreshRequest::new(user_uid, token) - .execute_sync(client, &DefaultRequestFactory {})?; - let user = UserAuth::from_auth_refresh_response(&refresh_response); - Ok(Session::new(user, on_auth_refreshed)) - } - - pub fn get_user(&self, client: &T) -> Result { - let user = self.repeater.execute(client, UserInfoRequest {})?; - Ok(user.user) - } + password, + hv: human_verification, + }; - pub async fn get_user_async( - &self, - client: &T, - ) -> Result { - let user = self - .repeater - .execute_async(client, UserInfoRequest {}) - .await?; - Ok(user.user) + StateProducerSequence::new(state, login_sequence_1) } - pub fn logout(&self, client: &T) -> Result<(), http::Error> { - LogoutRequest {}.execute_sync::(client, &self.repeater) + pub fn submit_totp(&self, code: &str) -> impl Sequence { + self.wrap_request(TOTPRequest::new(code).to_request()) } - pub async fn logout_async(&self, client: &T) -> Result<(), http::Error> { - LogoutRequest {} - .execute_async::(client, &self.repeater) - .await + pub fn refresh<'a>( + user_uid: &'a UserUid, + token: &'a str, + ) -> impl Sequence<'a, Output = Self, Error = http::Error> + 'a { + AuthRefreshRequest::new(user_uid, token) + .to_request() + .map(|r| { + let user = UserAuth::from_auth_refresh_response(r); + Ok(Session::new(user)) + }) } - pub fn get_latest_event(&self, client: &T) -> http::Result { - let r = self.repeater.execute(client, GetLatestEventRequest {})?; - Ok(r.event_id) + pub fn get_user(&self) -> impl Sequence { + self.wrap_request(UserInfoRequest {}.to_request()) + .map(|r| -> Result { Ok(r.user) }) } - pub async fn get_latest_event_async( - &self, - client: &T, - ) -> http::Result { - let r = self - .repeater - .execute_async(client, GetLatestEventRequest {}) - .await?; - Ok(r.event_id) + pub fn logout(&self) -> impl Sequence { + self.wrap_request(LogoutRequest {}.to_request()) } - pub fn get_event(&self, client: &T, id: &EventId) -> http::Result { - self.repeater.execute(client, GetEventRequest::new(id)) + pub fn get_latest_event(&self) -> impl Request { + self.wrap_request(GetLatestEventRequest {}.to_request()) } - pub async fn get_event_async( - &self, - client: &T, - id: &EventId, - ) -> http::Result { - self.repeater - .execute_async(client, GetEventRequest::new(id)) - .await + pub fn get_event(&self, id: &EventId) -> impl Request { + self.wrap_request(GetEventRequest::new(id).to_request()) } pub fn get_refresh_data(&self) -> SessionRefreshData { - self.repeater.get_refresh_data() + let reader = self.user_auth.read(); + SessionRefreshData { + user_uid: reader.uid.clone(), + token: reader.refresh_token.clone(), + } } -} -fn generate_session_proof( - username: &str, - password: &str, - auth_info_response: &AuthInfoResponse, -) -> Result { - SRPAuth::generate( - username, - password, - auth_info_response.version, - &auth_info_response.salt, - &auth_info_response.modulus, - &auth_info_response.server_ephemeral, - ) - .map_err(LoginError::ServerProof) + #[inline(always)] + fn wrap_request(&self, r: R) -> SessionRequest { + SessionRequest(r, self.user_auth.clone()) + } } fn validate_server_proof( proof: &SRPAuth, - auth_response: &AuthResponse, - on_auth_refreshed: Option>, + auth_response: AuthResponse, ) -> Result { if proof.expected_server_proof != auth_response.server_proof { return Err(LoginError::ServerProof( @@ -221,11 +136,12 @@ fn validate_server_proof( )); } + let tfa_enabled = auth_response.tfa.enabled; let user = UserAuth::from_auth_response(auth_response); - let session = Session::new(user, on_auth_refreshed); + let session = Session::new(user); - match auth_response.tfa.enabled { + match tfa_enabled { TFAStatus::None => Ok(SessionType::Authenticated(session)), TFAStatus::Totp => Ok(SessionType::AwaitingTotp(TotpSession(session))), TFAStatus::FIDO2 => Err(LoginError::Unsupported2FA(TwoFactorAuth::FIDO2)), @@ -233,12 +149,172 @@ fn validate_server_proof( } } -fn map_human_verification_err(e: http::Error) -> LoginError { - if let http::Error::API(e) = &e { +fn map_human_verification_err(e: LoginError) -> LoginError { + if let LoginError::Request(http::Error::API(e)) = &e { if let Ok(hv) = e.try_get_human_verification_details() { return LoginError::HumanVerificationRequired(hv); } } - LoginError::from(e) + e +} + +pub struct SessionRequest(R, Arc>); + +impl SessionRequest { + fn refresh_auth(&self) -> impl Sequence<'_, Output = (), Error = http::Error> + '_ { + let reader = self.1.read(); + AuthRefreshRequest::new( + reader.uid.expose_secret(), + reader.refresh_token.expose_secret(), + ) + .to_request() + .map(|resp| { + let mut writer = self.1.write(); + *writer = UserAuth::from_auth_refresh_response(resp); + Ok(()) + }) + } + + async fn exec_async_impl<'a, C: ClientAsync, F: FromResponse>( + &'a self, + client: &'a C, + ) -> Result { + let v = self.build(client); + match client.execute_async::(v).await { + Ok(r) => Ok(r), + Err(original_error) => { + if let http::Error::API(api_err) = &original_error { + if api_err.http_code == 401 { + log::debug!("Account session expired, attempting refresh"); + // Session expired/not authorized, try auth refresh. + if let Err(e) = self.refresh_auth().do_async(client).await { + log::error!("Failed to refresh account {e}"); + return Err(original_error); + } + + // Execute request again + return client.execute_async::(self.build(client)).await; + } + } + Err(original_error) + } + } + } +} + +impl Request for SessionRequest { + type Response = R::Response; + + fn build(&self, builder: &C) -> C::Request { + let r = self.0.build(builder); + let borrow = self.1.read(); + r.header(X_PM_UID_HEADER, borrow.uid.expose_secret().as_str()) + .bearer_token(borrow.access_token.expose_secret()) + } + + fn exec_sync( + &self, + client: &T, + ) -> Result<::Output, http::Error> { + match client.execute::(self.build(client)) { + Ok(r) => Ok(r), + Err(original_error) => { + if let http::Error::API(api_err) = &original_error { + if api_err.http_code == 401 { + log::debug!("Account session expired, attempting refresh"); + // Session expired/not authorized, try auth refresh. + if let Err(e) = self.refresh_auth().do_sync(client) { + log::error!("Failed to refresh account {e}"); + return Err(original_error); + } + + // Execute request again + return client.execute::(self.build(client)); + } + } + Err(original_error) + } + } + } + + #[cfg(not(feature = "async-traits"))] + fn exec_async<'a, T: ClientAsync>( + &'a self, + client: &'a T, + ) -> Pin< + Box< + dyn Future::Output, http::Error>> + 'a, + >, + > { + Box::pin(async move { self.exec_async_impl::(client).await }) + } + + #[cfg(feature = "async-traits")] + async fn exec_async<'a, T: ClientAsync>( + &'a self, + client: &'a T, + ) -> Result<::Output, http::Error> { + self.exec_async_impl::(client).await + } +} + +struct State<'a> { + username: &'a str, + password: &'a SecretString, + hv: Option, +} + +struct LoginState<'a> { + username: &'a str, + proof: SRPAuth, + session: String, + hv: Option, +} + +fn generate_login_state( + state: State, + auth_info_response: AuthInfoResponse, +) -> Result { + let proof = SRPAuth::generate( + state.username, + state.password.expose_secret(), + auth_info_response.version, + &auth_info_response.salt, + &auth_info_response.modulus, + &auth_info_response.server_ephemeral, + ) + .map_err(LoginError::ServerProof)?; + + Ok(LoginState { + username: state.username, + proof, + session: auth_info_response.srp_session, + hv: state.hv, + }) +} + +fn login_sequence_2( + login_state: LoginState, +) -> impl Sequence<'_, Output = SessionType, Error = LoginError> + '_ { + AuthRequest { + username: login_state.username, + client_ephemeral: &login_state.proof.client_ephemeral, + client_proof: &login_state.proof.client_proof, + srp_session: &login_state.session, + human_verification: &login_state.hv, + } + .to_request() + .map(move |auth_response| { + validate_server_proof(&login_state.proof, auth_response).map_err(map_human_verification_err) + }) +} + +fn login_sequence_1(st: State) -> impl Sequence<'_, Output = SessionType, Error = LoginError> + '_ { + AuthInfoRequest { + username: st.username, + } + .to_request() + .map(move |auth_info_response| generate_login_state(st, auth_info_response)) + .state(login_sequence_2) } diff --git a/src/clientv2/totp.rs b/src/clientv2/totp.rs index e77567f..187a034 100644 --- a/src/clientv2/totp.rs +++ b/src/clientv2/totp.rs @@ -1,42 +1,19 @@ use crate::clientv2::Session; use crate::http; -use crate::http::Request; -use crate::requests::TOTPRequest; +use crate::http::Sequence; #[derive(Debug)] pub struct TotpSession(pub(super) Session); impl TotpSession { - pub fn submit_totp( - self, - client: &T, - code: &str, - ) -> Result { - match TOTPRequest::new(code).execute_sync(client, &self.0.repeater) { - Err(e) => Err((self, e)), - Ok(_) => Ok(self.0), - } + pub fn submit_totp(&self, code: &str) -> impl Sequence { + let auth = self.0.user_auth.clone(); + self.0 + .submit_totp(code) + .map(move |_| Ok(Session { user_auth: auth })) } - pub async fn submit_totp_async( - self, - client: &T, - code: &str, - ) -> Result { - match TOTPRequest::new(code) - .execute_async(client, &self.0.repeater) - .await - { - Err(e) => Err((self, e)), - Ok(_) => Ok(self.0), - } - } - - pub fn logout(&self, client: &T) -> http::Result<()> { - self.0.logout(client) - } - - pub async fn logout_async(&self, client: &T) -> http::Result<()> { - self.0.logout_async(client).await + pub fn logout(&self) -> impl Sequence { + self.0.logout() } } diff --git a/src/domain/human_verification.rs b/src/domain/human_verification.rs index cc9866c..995e473 100644 --- a/src/domain/human_verification.rs +++ b/src/domain/human_verification.rs @@ -36,7 +36,7 @@ impl std::fmt::Display for HumanVerificationType { } /// Human Verification data required for Login. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct HumanVerificationLoginData { /// Type of human verification where the code originated from. pub hv_type: HumanVerificationType, diff --git a/src/http/client.rs b/src/http/client.rs new file mode 100644 index 0000000..7ca8d98 --- /dev/null +++ b/src/http/client.rs @@ -0,0 +1,157 @@ +use crate::http::{Proxy, RequestData, Result, DEFAULT_APP_VERSION, DEFAULT_HOST_URL}; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; +use std::time::Duration; + +/// Builder for an http client +#[derive(Debug, Clone)] +pub struct ClientBuilder { + pub(super) app_version: String, + pub(super) base_url: String, + pub(super) request_timeout: Option, + pub(super) connect_timeout: Option, + pub(super) user_agent: String, + pub(super) proxy_url: Option, + pub(super) debug: bool, + pub(super) allow_http: bool, +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ClientBuilder { + pub fn new() -> Self { + Self { + app_version: DEFAULT_APP_VERSION.to_string(), + user_agent: "NoClient/0.1.0".to_string(), + base_url: DEFAULT_HOST_URL.to_string(), + request_timeout: None, + connect_timeout: None, + proxy_url: None, + debug: false, + allow_http: false, + } + } + + /// Set the app version for this client e.g.: my-client@1.4.0+beta. + /// Note: The default app version is not guaranteed to be accepted by the proton servers. + pub fn app_version(mut self, version: &str) -> Self { + self.app_version = version.to_string(); + self + } + + /// Set the user agent to be submitted with every request. + pub fn user_agent(mut self, agent: &str) -> Self { + self.user_agent = agent.to_string(); + self + } + + /// Set server's base url. By default the proton API server url is used. + pub fn base_url(mut self, url: &str) -> Self { + self.base_url = url.to_string(); + self + } + + /// Set the full request timeout. By default there is no timeout. + pub fn request_timeout(mut self, duration: Duration) -> Self { + self.request_timeout = Some(duration); + self + } + + /// Set the connection timeout. By default there is no timeout. + pub fn connect_timeout(mut self, duration: Duration) -> Self { + self.connect_timeout = Some(duration); + self + } + + /// Specify proxy URL for the builder. + pub fn with_proxy(mut self, proxy: Proxy) -> Self { + self.proxy_url = Some(proxy); + self + } + + /// Allow http request + pub fn allow_http(mut self) -> Self { + self.allow_http = true; + self + } + + /// Enable request debugging. + pub fn debug(mut self) -> Self { + self.debug = true; + self + } + + pub fn build>( + self, + ) -> std::result::Result { + T::try_from(self) + } +} + +pub trait ClientRequest: Sized { + fn header(self, key: impl AsRef, value: impl AsRef) -> Self; + + fn bearer_token(self, token: impl AsRef) -> Self { + self.header("authorization", format!("Bearer {}", token.as_ref())) + } +} + +pub trait ClientRequestBuilder { + type Request: ClientRequest; + fn new_request(&self, data: &RequestData) -> Self::Request; +} + +/// HTTP Client abstraction Sync. +pub trait ClientSync: ClientRequestBuilder + TryFrom { + fn execute(&self, request: Self::Request) -> Result; +} + +/// HTTP Client abstraction Async. +pub trait ClientAsync: + ClientRequestBuilder + TryFrom +{ + #[cfg(not(feature = "async-traits"))] + fn execute_async( + &self, + request: Self::Request, + ) -> Pin> + '_>>; + + #[cfg(feature = "async-traits")] + async fn execute_async(&self, request: Self::Request) -> Result; +} + +pub trait ResponseBodySync { + type Body: AsRef<[u8]>; + fn get_body(self) -> Result; +} + +pub trait ResponseBodyAsync { + type Body: AsRef<[u8]>; + + #[cfg(not(feature = "async-traits"))] + fn get_body_async(self) -> Pin>>>; + + #[cfg(feature = "async-traits")] + async fn get_body_async(self) -> Result; +} + +pub trait FromResponse { + type Output; + fn from_response_sync(response: T) -> Result; + + #[cfg(not(feature = "async-traits"))] + fn from_response_async( + response: T, + ) -> Pin>>>; + + #[cfg(feature = "async-traits")] + async fn from_response_async( + response: T, + ) -> Result; +} diff --git a/src/http/mod.rs b/src/http/mod.rs index 8971d15..36ec496 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -1,15 +1,7 @@ //! Basic HTTP Protocol abstraction for the Proton API. -use crate::domain::SecretString; use anyhow; -use secrecy::ExposeSecret; -use serde::de::DeserializeOwned; -use serde::Serialize; -use std::collections::HashMap; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::time::Duration; +use std::fmt::Debug; use thiserror::Error; #[cfg(feature = "http-ureq")] @@ -18,6 +10,18 @@ pub mod ureq_client; #[cfg(feature = "http-reqwest")] pub mod reqwest_client; +mod client; +mod proxy; +mod request; +mod response; +mod sequence; + +pub use client::*; +pub use proxy::*; +pub use request::*; +pub use response::*; +pub use sequence::*; + pub(crate) const DEFAULT_HOST_URL: &str = "https://mail.proton.me/api"; pub(crate) const DEFAULT_APP_VERSION: &str = "proton-api-rs"; #[allow(unused)] // it is used by the http implementations @@ -36,52 +40,6 @@ pub enum Method { Patch, } -/// HTTP Request representation. -#[derive(Debug)] -pub struct RequestData { - #[allow(unused)] // Only used by http implementations. - pub(super) method: Method, - #[allow(unused)] // Only used by http implementations. - pub(super) url: String, - pub(super) headers: HashMap, - pub(super) body: Option>, -} - -impl RequestData { - pub fn new(method: Method, url: impl Into) -> Self { - Self { - method, - url: url.into(), - headers: HashMap::new(), - body: None, - } - } - - pub fn header(mut self, key: impl Into, value: impl Into) -> Self { - self.headers.insert(key.into(), value.into()); - self - } - - pub fn bearer_token(self, token: &str) -> Self { - self.header("authorization", format!("Bearer {token}")) - } - - pub fn bytes(mut self, bytes: Vec) -> Self { - self.body = Some(bytes); - self - } - - pub fn json(self, value: impl Serialize) -> Self { - let bytes = serde_json::to_vec(&value).expect("Failed to serialize json"); - self.json_bytes(bytes) - } - - pub fn json_bytes(mut self, bytes: Vec) -> Self { - self.body = Some(bytes); - self.header("Content-Type", "application/json") - } -} - /// Errors that may occur during an HTTP request, mostly related to network. #[derive(Debug, Error)] pub enum Error { @@ -108,265 +66,3 @@ impl From for Error { } pub type Result = std::result::Result; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub enum ProxyProtocol { - Https, - Socks5, -} - -#[derive(Debug, Clone)] -pub struct ProxyAuth { - pub username: String, - pub password: SecretString, -} - -#[derive(Debug, Clone)] -pub struct Proxy { - pub protocol: ProxyProtocol, - pub auth: Option, - pub url: String, - pub port: u16, -} - -impl Proxy { - pub fn as_url(&self) -> String { - let protocol = match self.protocol { - ProxyProtocol::Https => "https", - ProxyProtocol::Socks5 => "socks5", - }; - - let auth = if let Some(auth) = &self.auth { - format!("{}:{}@", auth.username, auth.password.expose_secret()) - } else { - String::new() - }; - - format!("{protocol}://{auth}{}:{}", self.url, self.port) - } -} - -/// Builder for an http client -#[derive(Debug, Clone)] -pub struct ClientBuilder { - app_version: String, - base_url: String, - request_timeout: Option, - connect_timeout: Option, - user_agent: String, - proxy_url: Option, - debug: bool, - allow_http: bool, -} - -impl Default for ClientBuilder { - fn default() -> Self { - Self::new() - } -} - -impl ClientBuilder { - pub fn new() -> Self { - Self { - app_version: DEFAULT_APP_VERSION.to_string(), - user_agent: "NoClient/0.1.0".to_string(), - base_url: DEFAULT_HOST_URL.to_string(), - request_timeout: None, - connect_timeout: None, - proxy_url: None, - debug: false, - allow_http: false, - } - } - - /// Set the app version for this client e.g.: my-client@1.4.0+beta. - /// Note: The default app version is not guaranteed to be accepted by the proton servers. - pub fn app_version(mut self, version: &str) -> Self { - self.app_version = version.to_string(); - self - } - - /// Set the user agent to be submitted with every request. - pub fn user_agent(mut self, agent: &str) -> Self { - self.user_agent = agent.to_string(); - self - } - - /// Set server's base url. By default the proton API server url is used. - pub fn base_url(mut self, url: &str) -> Self { - self.base_url = url.to_string(); - self - } - - /// Set the full request timeout. By default there is no timeout. - pub fn request_timeout(mut self, duration: Duration) -> Self { - self.request_timeout = Some(duration); - self - } - - /// Set the connection timeout. By default there is no timeout. - pub fn connect_timeout(mut self, duration: Duration) -> Self { - self.connect_timeout = Some(duration); - self - } - - /// Specify proxy URL for the builder. - pub fn with_proxy(mut self, proxy: Proxy) -> Self { - self.proxy_url = Some(proxy); - self - } - - /// Allow http request - pub fn allow_http(mut self) -> Self { - self.allow_http = true; - self - } - - /// Enable request debugging. - pub fn debug(mut self) -> Self { - self.debug = true; - self - } - - pub fn build>( - self, - ) -> std::result::Result { - T::try_from(self) - } -} - -/// Abstraction for request creation, this can enable wrapping of request creations to add -/// session token or other headers. -pub trait RequestFactory { - fn new_request(&self, method: Method, url: &str) -> RequestData; -} - -/// Default request factory, creates basic requests. -#[derive(Copy, Clone)] -pub struct DefaultRequestFactory {} - -impl RequestFactory for DefaultRequestFactory { - fn new_request(&self, method: Method, url: &str) -> RequestData { - RequestData::new(method, url) - } -} - -pub trait ResponseBodySync { - type Body: AsRef<[u8]>; - fn get_body(self) -> Result; -} - -pub trait ResponseBodyAsync { - type Body: AsRef<[u8]>; - fn get_body_async(self) -> Pin>>>; -} - -pub trait FromResponse { - type Output; - fn from_response_sync(response: T) -> Result; - - fn from_response_async( - response: T, - ) -> Pin>>>; -} - -#[derive(Copy, Clone)] -pub struct NoResponse {} - -impl FromResponse for NoResponse { - type Output = (); - - fn from_response_sync(_: T) -> Result { - Ok(()) - } - - fn from_response_async( - _: T, - ) -> Pin>>> { - Box::pin(async { Ok(()) }) - } -} - -pub struct JsonResponse(PhantomData); - -impl FromResponse for JsonResponse { - type Output = T; - - fn from_response_sync(response: R) -> Result { - let body = response.get_body()?; - let r = serde_json::from_slice(body.as_ref())?; - Ok(r) - } - - fn from_response_async( - response: R, - ) -> Pin>>> { - Box::pin(async move { - let body = response.get_body_async().await?; - let r = serde_json::from_slice(body.as_ref())?; - Ok(r) - }) - } -} - -#[derive(Copy, Clone)] -pub struct StringResponse {} - -impl FromResponse for StringResponse { - type Output = String; - - fn from_response_sync(response: R) -> Result { - let body = response.get_body()?; - Ok(String::from_utf8_lossy(body.as_ref()).to_string()) - } - - fn from_response_async( - response: R, - ) -> Pin>>> { - Box::pin(async move { - let body = response.get_body_async().await?; - Ok(String::from_utf8_lossy(body.as_ref()).to_string()) - }) - } -} - -pub trait Request { - type Output: Sized; - type Response: FromResponse; - - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData; - - fn execute_sync( - &self, - client: &T, - factory: &dyn RequestFactory, - ) -> Result { - client.execute(self, factory) - } - - fn execute_async( - &self, - client: &T, - factory: &dyn RequestFactory, - ) -> Pin>>> { - client.execute_async(self, factory) - } -} - -/// HTTP Client abstraction Sync. -pub trait ClientSync: TryFrom { - fn execute( - &self, - request: &R, - factory: &dyn RequestFactory, - ) -> Result; -} - -/// HTTP Client abstraction Async. -pub trait ClientAsync: TryFrom { - fn execute_async( - &self, - request: &R, - factory: &dyn RequestFactory, - ) -> Pin>>>; -} diff --git a/src/http/proxy.rs b/src/http/proxy.rs new file mode 100644 index 0000000..1804d31 --- /dev/null +++ b/src/http/proxy.rs @@ -0,0 +1,39 @@ +use crate::domain::SecretString; +use secrecy::ExposeSecret; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ProxyProtocol { + Https, + Socks5, +} + +#[derive(Debug, Clone)] +pub struct ProxyAuth { + pub username: String, + pub password: SecretString, +} + +#[derive(Debug, Clone)] +pub struct Proxy { + pub protocol: ProxyProtocol, + pub auth: Option, + pub url: String, + pub port: u16, +} + +impl Proxy { + pub fn as_url(&self) -> String { + let protocol = match self.protocol { + ProxyProtocol::Https => "https", + ProxyProtocol::Socks5 => "socks5", + }; + + let auth = if let Some(auth) = &self.auth { + format!("{}:{}@", auth.username, auth.password.expose_secret()) + } else { + String::new() + }; + + format!("{protocol}://{auth}{}:{}", self.url, self.port) + } +} diff --git a/src/http/request.rs b/src/http/request.rs new file mode 100644 index 0000000..7226f16 --- /dev/null +++ b/src/http/request.rs @@ -0,0 +1,112 @@ +use crate::http::{ClientAsync, ClientRequestBuilder, ClientSync, Error, FromResponse, Method}; +use bytes::Bytes; +use serde::Serialize; +use std::collections::HashMap; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +use std::marker::PhantomData; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; + +/// HTTP Request representation. +#[derive(Debug)] +pub struct RequestData { + #[allow(unused)] // Only used by http implementations. + pub(super) method: Method, + #[allow(unused)] // Only used by http implementations. + pub(super) url: String, + pub(super) headers: HashMap, + pub(super) body: Option, +} + +impl RequestData { + pub fn new(method: Method, url: impl Into) -> Self { + Self { + method, + url: url.into(), + headers: HashMap::new(), + body: None, + } + } + + pub fn header(mut self, key: impl Into, value: impl Into) -> Self { + self.headers.insert(key.into(), value.into()); + self + } + + pub fn bearer_token(self, token: &str) -> Self { + self.header("authorization", format!("Bearer {token}")) + } + + pub fn bytes(mut self, bytes: impl Into) -> Self { + self.body = Some(bytes.into()); + self + } + + pub fn json(self, value: impl Serialize) -> Self { + let bytes = serde_json::to_vec(&value).expect("Failed to serialize json"); + self.json_bytes(bytes) + } + + pub fn json_bytes(mut self, bytes: impl Into) -> Self { + self.body = Some(bytes.into()); + self.header("Content-Type", "application/json") + } +} + +pub trait RequestDesc { + type Output: Sized; + type Response: FromResponse; + + fn build(&self) -> RequestData; + + fn to_request(&self) -> RequestWrapper { + let data = self.build(); + RequestWrapper(data, PhantomData) + } +} + +pub struct RequestWrapper(RequestData, PhantomData); + +impl Request for RequestWrapper { + type Response = F; + + fn build(&self, builder: &C) -> C::Request { + builder.new_request(&self.0) + } +} + +#[cfg(not(feature = "async-traits"))] +type RequestFuture<'a, F> = + Pin::Output, Error>> + 'a>>; + +pub trait Request { + type Response: FromResponse; + + fn build(&self, builder: &C) -> C::Request; + + fn exec_sync( + &self, + client: &T, + ) -> Result<::Output, Error> { + client.execute::(self.build(client)) + } + + #[cfg(not(feature = "async-traits"))] + fn exec_async<'a, T: ClientAsync>( + &'a self, + client: &'a T, + ) -> RequestFuture<'a, Self::Response> { + let v = self.build(client); + Box::pin(async move { client.execute_async::(v).await }) + } + + #[cfg(feature = "async-traits")] + async fn exec_async<'a, T: ClientAsync>( + &'a self, + client: &'a T, + ) -> Result<::Output, Error> { + let v = self.build(client); + client.execute_async::(v).await + } +} diff --git a/src/http/reqwest_client.rs b/src/http/reqwest_client.rs index fb6d150..7383b2b 100644 --- a/src/http/reqwest_client.rs +++ b/src/http/reqwest_client.rs @@ -1,11 +1,14 @@ use crate::http::{ - ClientAsync, ClientBuilder, Error, FromResponse, Method, Request, RequestFactory, - ResponseBodyAsync, X_PM_APP_VERSION_HEADER, + ClientAsync, ClientBuilder, ClientRequest, ClientRequestBuilder, Error, FromResponse, Method, + RequestData, ResponseBodyAsync, X_PM_APP_VERSION_HEADER, }; use crate::requests::APIError; use bytes::Bytes; use reqwest; + +#[cfg(not(feature = "async-traits"))] use std::future::Future; +#[cfg(not(feature = "async-traits"))] use std::pin::Pin; #[derive(Debug)] @@ -43,7 +46,7 @@ impl TryFrom for ReqwestClient { builder = builder .min_tls_version(Version::TLS_1_2) - .https_only(true) + .https_only(!value.allow_http) .cookie_store(true) .user_agent(value.user_agent) .default_headers(header_map); @@ -87,28 +90,39 @@ impl From for Error { struct ReqwestResponse(reqwest::Response); +pub struct ReqwestRequest(reqwest::RequestBuilder); + +impl ClientRequest for ReqwestRequest { + fn header(self, key: impl AsRef, value: impl AsRef) -> Self { + Self(self.0.header(key.as_ref(), value.as_ref())) + } +} + impl ResponseBodyAsync for ReqwestResponse { type Body = Bytes; + #[cfg(not(feature = "async-traits"))] fn get_body_async(self) -> Pin>>> { Box::pin(async { let bytes = self.0.bytes().await?; Ok(bytes) }) } + + #[cfg(feature = "async-traits")] + async fn get_body_async(self) -> crate::http::Result { + let bytes = self.0.bytes().await?; + Ok(bytes) + } } -impl ClientAsync for ReqwestClient { - fn execute_async( - &self, - r: &R, - factory: &dyn RequestFactory, - ) -> Pin>>> { - let request = r.build_request(factory); +impl ClientRequestBuilder for ReqwestClient { + type Request = ReqwestRequest; - let final_url = format!("{}/{}", self.base_url, request.url); + fn new_request(&self, data: &RequestData) -> Self::Request { + let final_url = format!("{}/{}", self.base_url, data.url); - let mut rrequest = match request.method { + let mut request = match data.method { Method::Delete => self.client.delete(&final_url), Method::Get => self.client.get(&final_url), Method::Put => self.client.put(&final_url), @@ -117,32 +131,57 @@ impl ClientAsync for ReqwestClient { }; // Set headers. - for (header, value) in &request.headers { - rrequest = rrequest.header(header, value); + for (header, value) in &data.headers { + request = request.header(header, value); } - if let Some(body) = &request.body { - rrequest = rrequest.body(body.to_vec()) + if let Some(body) = &data.body { + request = request.body(body.clone()) } - Box::pin(async move { - let response = rrequest.send().await?; + ReqwestRequest(request) + } +} - let status = response.status().as_u16(); +impl ReqwestClient { + pub async fn direct_exec( + &self, + r: ReqwestRequest, + ) -> crate::http::Result { + let response = r.0.send().await?; + + let status = response.status().as_u16(); + + if status >= 400 { + let body = response + .bytes() + .await + .map_err(|_| Error::API(APIError::new(status)))?; + + return Err(Error::API(APIError::with_status_and_body( + status, + body.as_ref(), + ))); + } - if status >= 400 { - let body = response - .bytes() - .await - .map_err(|_| Error::API(APIError::new(status)))?; + R::from_response_async(ReqwestResponse(response)).await + } +} - return Err(Error::API(APIError::with_status_and_body( - status, - body.as_ref(), - ))); - } +impl ClientAsync for ReqwestClient { + #[cfg(not(feature = "async-traits"))] + fn execute_async( + &self, + r: Self::Request, + ) -> Pin> + '_>> { + Box::pin(async move { self.direct_exec::(r).await }) + } - R::Response::from_response_async(ReqwestResponse(response)).await - }) + #[cfg(feature = "async-traits")] + async fn execute_async( + &self, + request: Self::Request, + ) -> crate::http::Result { + self.direct_exec::(request).await } } diff --git a/src/http/response.rs b/src/http/response.rs new file mode 100644 index 0000000..ecc9650 --- /dev/null +++ b/src/http/response.rs @@ -0,0 +1,92 @@ +use crate::http::{FromResponse, ResponseBodyAsync, ResponseBodySync, Result}; +use serde::de::DeserializeOwned; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +use std::marker::PhantomData; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; + +#[derive(Copy, Clone)] +pub struct NoResponse {} + +impl FromResponse for NoResponse { + type Output = (); + + fn from_response_sync(_: T) -> Result { + Ok(()) + } + + #[cfg(not(feature = "async-traits"))] + fn from_response_async( + _: T, + ) -> Pin>>> { + Box::pin(async { Ok(()) }) + } + + #[cfg(feature = "async-traits")] + async fn from_response_async(_: T) -> Result { + Ok(()) + } +} + +pub struct JsonResponse(PhantomData); + +impl FromResponse for JsonResponse { + type Output = T; + + fn from_response_sync(response: R) -> Result { + let body = response.get_body()?; + let r = serde_json::from_slice(body.as_ref())?; + Ok(r) + } + + #[cfg(not(feature = "async-traits"))] + fn from_response_async( + response: R, + ) -> Pin>>> { + Box::pin(async move { + let body = response.get_body_async().await?; + let r = serde_json::from_slice(body.as_ref())?; + Ok(r) + }) + } + + #[cfg(feature = "async-traits")] + async fn from_response_async( + response: R, + ) -> Result { + let body = response.get_body_async().await?; + let r = serde_json::from_slice(body.as_ref())?; + Ok(r) + } +} + +#[derive(Copy, Clone)] +pub struct StringResponse {} + +impl FromResponse for StringResponse { + type Output = String; + + fn from_response_sync(response: R) -> Result { + let body = response.get_body()?; + Ok(String::from_utf8_lossy(body.as_ref()).to_string()) + } + + #[cfg(not(feature = "async-traits"))] + fn from_response_async( + response: R, + ) -> Pin>>> { + Box::pin(async move { + let body = response.get_body_async().await?; + Ok(String::from_utf8_lossy(body.as_ref()).to_string()) + }) + } + + #[cfg(feature = "async-traits")] + async fn from_response_async( + response: R, + ) -> Result { + let body = response.get_body_async().await?; + Ok(String::from_utf8_lossy(body.as_ref()).to_string()) + } +} diff --git a/src/http/sequence.rs b/src/http/sequence.rs new file mode 100644 index 0000000..f441ac7 --- /dev/null +++ b/src/http/sequence.rs @@ -0,0 +1,275 @@ +use crate::http::{ClientAsync, ClientSync, Error, FromResponse, Request}; +use std::fmt::Debug; +#[cfg(not(feature = "async-traits"))] +use std::future::Future; +#[cfg(not(feature = "async-traits"))] +use std::pin::Pin; + +#[cfg(not(feature = "async-traits"))] +type SequenceFuture<'a, O, E> = Pin> + 'a>>; + +/// Trait which can be use to link a sequence of request operations. +pub trait Sequence<'a> { + type Output: 'a; + type Error: From + Debug; + + fn do_sync(self, client: &T) -> Result; + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> SequenceFuture<'a, Self::Output, Self::Error>; + + #[cfg(feature = "async-traits")] + async fn do_async(self, client: &'a T) -> Result; + + fn map Result>(self, f: F) -> MapSequence + where + Self: Sized, + E: From + From + Debug, + { + MapSequence { c: self, f } + } + + fn state(self, f: F) -> StateSequence + where + Self: Sized, + SS: Sequence<'a>, + F: FnOnce(Self::Output) -> SS, + >::Error: From + From + Debug, + { + StateSequence { seq: self, f } + } + + fn chain(self, f: F) -> SequenceChain + where + SS: Sequence<'a>, + F: FnOnce(Self::Output) -> Result, + E: From + Debug, + >::Error: From + From + Debug, + Self: Sized, + { + SequenceChain { s: self, f } + } +} + +impl<'a, R: Request + 'a> Sequence<'a> for R +where + ::Output: 'a, +{ + type Output = ::Output; + type Error = Error; + + fn do_sync(self, client: &T) -> Result { + self.exec_sync(client) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + Box::pin(async move { self.exec_async(client).await }) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result<>::Output, >::Error> { + self.exec_async(client).await + } +} + +#[doc(hidden)] +pub struct MapSequence { + c: C, + f: F, +} + +impl<'a, C, O, E, F> Sequence<'a> for MapSequence +where + O: 'a, + C: Sequence<'a> + 'a, + F: FnOnce(C::Output) -> Result + 'a, + E: From + Debug + From, +{ + type Output = O; + type Error = E; + + fn do_sync(self, client: &T) -> Result { + let v = self.c.do_sync(client)?; + let r = (self.f)(v)?; + Ok(r) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + Box::pin(async move { + let v = self.c.do_async(client).await?; + let r = (self.f)(v)?; + Ok(r) + }) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result< + as Sequence<'a>>::Output, + as Sequence<'a>>::Error, + > { + let v = self.c.do_async(client).await?; + let r = (self.f)(v)?; + Ok(r) + } +} + +#[doc(hidden)] +pub struct StateSequence { + seq: S, + f: F, +} + +impl<'a, S, SS, F> Sequence<'a> for StateSequence +where + S: Sequence<'a> + 'a, + SS: Sequence<'a>, + >::Error: From<>::Error> + From + Debug, + F: FnOnce(S::Output) -> SS + 'a, +{ + type Output = SS::Output; + type Error = SS::Error; + + fn do_sync(self, client: &T) -> Result { + let state = self.seq.do_sync(client)?; + let ss = (self.f)(state); + ss.do_sync(client) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + Box::pin(async move { + let state = self.seq.do_async(client).await?; + let ss = (self.f)(state); + ss.do_async(client).await + }) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result< + as Sequence<'a>>::Output, + as Sequence<'a>>::Error, + > { + let state = self.seq.do_async(client).await?; + let ss = (self.f)(state); + ss.do_async(client).await + } +} + +#[doc(hidden)] +pub struct StateProducerSequence { + s: S, + f: F, +} + +impl StateProducerSequence { + pub fn new(s: S, f: F) -> Self { + Self { s, f } + } +} + +impl<'a, Seq, S, F> Sequence<'a> for StateProducerSequence +where + Seq: Sequence<'a>, + F: FnOnce(S) -> Seq, +{ + type Output = Seq::Output; + type Error = Seq::Error; + + fn do_sync(self, client: &T) -> Result { + let seq = (self.f)(self.s); + seq.do_sync(client) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + let seq = (self.f)(self.s); + seq.do_async(client) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result< + as Sequence<'a>>::Output, + as Sequence<'a>>::Error, + > { + let seq = (self.f)(self.s); + seq.do_async(client).await + } +} + +#[doc(hidden)] +pub struct SequenceChain { + s: S, + f: F, +} + +impl<'a, SS, S, E, F> Sequence<'a> for SequenceChain +where + SS: Sequence<'a>, + S: Sequence<'a> + 'a, + F: FnOnce(S::Output) -> Result + 'a, + E: From + Debug, + >::Error: From + From + Debug, +{ + type Output = SS::Output; + type Error = SS::Error; + + fn do_sync(self, client: &T) -> Result { + let v = self.s.do_sync(client)?; + let ss = (self.f)(v)?; + ss.do_sync(client) + } + + #[cfg(not(feature = "async-traits"))] + fn do_async( + self, + client: &'a T, + ) -> Pin> + 'a>> { + Box::pin(async move { + let v = self.s.do_async(client).await?; + let ss = (self.f)(v)?; + ss.do_async(client).await + }) + } + + #[cfg(feature = "async-traits")] + async fn do_async( + self, + client: &'a T, + ) -> Result< + as Sequence<'a>>::Output, + as Sequence<'a>>::Error, + > { + let v = self.s.do_async(client).await?; + let ss = (self.f)(v)?; + ss.do_async(client).await + } +} diff --git a/src/http/ureq_client.rs b/src/http/ureq_client.rs index 6d5873e..48efc8b 100644 --- a/src/http/ureq_client.rs +++ b/src/http/ureq_client.rs @@ -1,7 +1,10 @@ //! UReq HTTP client implementation. -use crate::http::{ClientBuilder, ClientSync, Error, FromResponse, Method, ResponseBodySync}; -use crate::http::{Request, RequestFactory, X_PM_APP_VERSION_HEADER}; +use crate::http::X_PM_APP_VERSION_HEADER; +use crate::http::{ + ClientBuilder, ClientRequest, ClientRequestBuilder, ClientSync, Error, FromResponse, Method, + RequestData, ResponseBodySync, +}; use crate::requests::APIError; use log::debug; use std::io; @@ -116,13 +119,22 @@ impl ResponseBodySync for UReqDebugResponse { } } -impl ClientSync for UReqClient { - fn execute( - &self, - r: &R, - factory: &dyn RequestFactory, - ) -> Result { - let request = r.build_request(factory); +pub struct UReqRequest { + request: ureq::Request, + body: Option, +} + +impl ClientRequest for UReqRequest { + fn header(mut self, key: impl AsRef, value: impl AsRef) -> Self { + self.request = self.request.set(key.as_ref(), value.as_ref()); + self + } +} + +impl ClientRequestBuilder for UReqClient { + type Request = UReqRequest; + + fn new_request(&self, request: &RequestData) -> Self::Request { let final_url = format!("{}/{}", self.base_url, request.url); let mut ureq_request = match request.method { Method::Delete => self.agent.delete(&final_url), @@ -140,16 +152,25 @@ impl ClientSync for UReqClient { ureq_request = ureq_request.set(header, value); } - let ureq_response = if let Some(body) = &request.body { - ureq_request.send_bytes(body)? + Self::Request { + request: ureq_request, + body: request.body.clone(), + } + } +} + +impl ClientSync for UReqClient { + fn execute(&self, request: Self::Request) -> Result { + let ureq_response = if let Some(body) = request.body { + request.request.send_bytes(body.as_ref())? } else { - ureq_request.call()? + request.request.call()? }; if !self.debug { - R::Response::from_response_sync(UReqResponse(ureq_response)) + R::from_response_sync(UReqResponse(ureq_response)) } else { - R::Response::from_response_sync(UReqDebugResponse(ureq_response)) + R::from_response_sync(UReqDebugResponse(ureq_response)) } } } diff --git a/src/lib.rs b/src/lib.rs index 7c2d9b6..2097069 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "async-traits", allow(incomplete_features))] +#![cfg_attr(feature = "async-traits", feature(async_fn_in_trait))] // Enable clippy if our Cargo.toml file asked us to do so. #![cfg_attr(feature = "clippy", feature(plugin))] #![cfg_attr(feature = "clippy", plugin(clippy))] @@ -38,7 +40,8 @@ //! //! Login into a new session async: //! ``` -//! use proton_api_rs::{http, Session, SessionType}; +//! use proton_api_rs::{http, Session, SessionType, http::Sequence}; +//! use proton_api_rs::domain::SecretString; //! async fn example() { //! let client = http::ClientBuilder::new() //! .user_agent("MyUserAgent/0.0.0") @@ -46,25 +49,26 @@ //! .app_version("MyApp@0.1.1") //! .build::().unwrap(); //! -//! let session = match Session::login_async(&client, "my_address@proton.me", "my_proton_password", None, None).await.unwrap(){ +//! let session = match Session::login(&"my_address@proton.me", &SecretString::new("my_proton_password".into()), None).do_async(&client).await.unwrap(){ //! // Session is authenticated, no 2FA verifications necessary. //! SessionType::Authenticated(c) => c, //! // Session needs 2FA TOTP auth. //! SessionType::AwaitingTotp(t) => { -//! t.submit_totp_async(&client, "000000").await.unwrap() +//! t.submit_totp("000000").do_async(&client).await.unwrap() //! } //! }; //! //! // session is now authenticated and can access the rest of the API. //! // ... //! -//! session.logout_async(&client).await.unwrap(); +//! session.logout().do_async(&client).await.unwrap(); //! } //! ``` //! //! Login into a new session sync: //! ``` -//! use proton_api_rs::{Session, http, SessionType}; +//! use proton_api_rs::{Session, http, SessionType, http::Sequence}; +//! use proton_api_rs::domain::SecretString; //! fn example() { //! let client = http::ClientBuilder::new() //! .user_agent("MyUserAgent/0.0.0") @@ -72,25 +76,25 @@ //! .app_version("MyApp@0.1.1") //! .build::().unwrap(); //! -//! let session = match Session::login(&client, "my_address@proton.me", "my_proton_password", None, None).unwrap(){ +//! let session = match Session::login("my_address@proton.me", &SecretString::new("my_proton_password".into()), None).do_sync(&client).unwrap(){ //! // Session is authenticated, no 2FA verifications necessary. //! SessionType::Authenticated(c) => c, //! // Session needs 2FA TOTP auth. //! SessionType::AwaitingTotp(t) => { -//! t.submit_totp(&client, "000000").unwrap() +//! t.submit_totp("000000").do_sync(&client).unwrap() //! } //! }; //! //! // session is now authenticated and can access the rest of the API. //! // ... //! -//! session.logout(&client).unwrap(); +//! session.logout().do_sync(&client).unwrap(); //! } //! ``` //! //! Login using a previous sessions token. //! ``` -//! use proton_api_rs::{http, Session, SessionType}; +//! use proton_api_rs::{http, Session, SessionType, http::Sequence}; //! use proton_api_rs::domain::UserUid; //! //! async fn example() { @@ -102,12 +106,12 @@ //! .app_version("MyApp@0.1.1") //! .build::().unwrap(); //! -//! let session = Session::refresh_async(&client, &user_uid, &user_refresh_token, None).await.unwrap(); +//! let session = Session::refresh(&user_uid, &user_refresh_token).do_async(&client).await.unwrap(); //! //! // session is now authenticated and can access the rest of the API. //! // ... //! -//! session.logout_async(&client).await.unwrap(); +//! session.logout().do_async(&client).await.unwrap(); //! } //! ``` diff --git a/src/requests/auth.rs b/src/requests/auth.rs index 8db3e8c..060c234 100644 --- a/src/requests/auth.rs +++ b/src/requests/auth.rs @@ -1,8 +1,6 @@ use crate::domain::{HumanVerificationLoginData, SecretString, UserUid}; use crate::http; -use crate::http::{ - RequestData, RequestFactory, X_PM_HUMAN_VERIFICATION_TOKEN, X_PM_HUMAN_VERIFICATION_TOKEN_TYPE, -}; +use crate::http::{RequestData, X_PM_HUMAN_VERIFICATION_TOKEN, X_PM_HUMAN_VERIFICATION_TOKEN_TYPE}; use secrecy::Secret; use serde::{Deserialize, Serialize}; use serde_repr::Deserialize_repr; @@ -15,27 +13,25 @@ pub struct AuthInfoRequest<'a> { pub username: &'a str, } -impl<'a> http::Request for AuthInfoRequest<'a> { - type Output = AuthInfoResponse<'a>; +impl<'a> http::RequestDesc for AuthInfoRequest<'a> { + type Output = AuthInfoResponse; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory - .new_request(http::Method::Post, "auth/v4/info") - .json(self) + fn build(&self) -> RequestData { + RequestData::new(http::Method::Post, "auth/v4/info").json(self) } } #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct AuthInfoResponse<'a> { +pub struct AuthInfoResponse { pub version: i64, - pub modulus: Cow<'a, str>, - pub server_ephemeral: Cow<'a, str>, - pub salt: Cow<'a, str>, + pub modulus: String, + pub server_ephemeral: String, + pub salt: String, #[serde(rename = "SRPSession")] - pub srp_session: Cow<'a, str>, + pub srp_session: String, } #[doc(hidden)] @@ -48,17 +44,15 @@ pub struct AuthRequest<'a> { #[serde(rename = "SRPSession")] pub srp_session: &'a str, #[serde(skip)] - pub human_verification: Option, + pub human_verification: &'a Option, } -impl<'a> http::Request for AuthRequest<'a> { - type Output = AuthResponse<'a>; +impl<'a> http::RequestDesc for AuthRequest<'a> { + type Output = AuthResponse; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - let mut request = factory - .new_request(http::Method::Post, "auth/v4") - .json(self); + fn build(&self) -> RequestData { + let mut request = RequestData::new(http::Method::Post, "auth/v4").json(self); if let Some(hv) = &self.human_verification { // repeat submission with x-pm-human-verification-token and x-pm-human-verification-token-type @@ -74,18 +68,18 @@ impl<'a> http::Request for AuthRequest<'a> { #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct AuthResponse<'a> { +pub struct AuthResponse { #[serde(rename = "UserID")] - pub user_id: Cow<'a, str>, + pub user_id: String, #[serde(rename = "UID")] - pub uid: Cow<'a, str>, - pub token_type: Option>, - pub access_token: Cow<'a, str>, - pub refresh_token: Cow<'a, str>, - pub server_proof: Cow<'a, str>, - pub scope: Cow<'a, str>, + pub uid: String, + pub token_type: Option, + pub access_token: String, + pub refresh_token: String, + pub server_proof: String, + pub scope: String, #[serde(rename = "2FA")] - pub tfa: TFAInfo<'a>, + pub tfa: TFAInfo, pub password_mode: PasswordMode, } @@ -110,10 +104,10 @@ pub enum TFAStatus { #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct TFAInfo<'a> { +pub struct TFAInfo { pub enabled: TFAStatus, #[serde(rename = "FIDO2")] - pub fido2_info: FIDO2Info<'a>, + pub fido2_info: FIDO2Info, } #[doc(hidden)] @@ -129,9 +123,9 @@ pub struct FIDOKey<'a> { #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct FIDO2Info<'a> { +pub struct FIDO2Info { pub authentication_options: serde_json::Value, - pub registered_keys: Option>>, + pub registered_keys: Option, } #[doc(hidden)] @@ -177,17 +171,15 @@ impl<'a> TOTPRequest<'a> { } } -impl<'a> http::Request for TOTPRequest<'a> { +impl<'a> http::RequestDesc for TOTPRequest<'a> { type Output = (); type Response = http::NoResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory - .new_request(http::Method::Post, "auth/v4/2fa") - .json(TFAAuth { - two_factor_code: self.code, - fido2: FIDO2Auth::empty(), - }) + fn build(&self) -> RequestData { + RequestData::new(http::Method::Post, "auth/v4/2fa").json(TFAAuth { + two_factor_code: self.code, + fido2: FIDO2Auth::empty(), + }) } } @@ -200,19 +192,19 @@ pub struct UserAuth { } impl UserAuth { - pub fn from_auth_response(auth: &AuthResponse) -> Self { + pub fn from_auth_response(auth: AuthResponse) -> Self { Self { - uid: Secret::new(UserUid(auth.uid.to_string())), - access_token: SecretString::new(auth.access_token.to_string()), - refresh_token: SecretString::new(auth.refresh_token.to_string()), + uid: Secret::new(UserUid(auth.uid)), + access_token: SecretString::new(auth.access_token), + refresh_token: SecretString::new(auth.refresh_token), } } - pub fn from_auth_refresh_response(auth: &AuthRefreshResponse) -> Self { + pub fn from_auth_refresh_response(auth: AuthRefreshResponse) -> Self { Self { - uid: Secret::new(UserUid(auth.uid.to_string())), - access_token: SecretString::new(auth.access_token.to_string()), - refresh_token: SecretString::new(auth.refresh_token.to_string()), + uid: Secret::new(UserUid(auth.uid)), + access_token: SecretString::new(auth.access_token), + refresh_token: SecretString::new(auth.refresh_token), } } } @@ -233,13 +225,13 @@ pub struct AuthRefresh<'a> { #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] -pub struct AuthRefreshResponse<'a> { +pub struct AuthRefreshResponse { #[serde(rename = "UID")] - pub uid: Cow<'a, str>, - pub token_type: Cow<'a, str>, - pub access_token: Cow<'a, str>, - pub refresh_token: Cow<'a, str>, - pub scope: Cow<'a, str>, + pub uid: String, + pub token_type: Option, + pub access_token: String, + pub refresh_token: String, + pub scope: String, } pub struct AuthRefreshRequest<'a> { @@ -253,31 +245,29 @@ impl<'a> AuthRefreshRequest<'a> { } } -impl<'a> http::Request for AuthRefreshRequest<'a> { - type Output = AuthRefreshResponse<'a>; +impl<'a> http::RequestDesc for AuthRefreshRequest<'a> { + type Output = AuthRefreshResponse; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory - .new_request(http::Method::Post, "auth/v4/refresh") - .json(AuthRefresh { - uid: &self.uid.0, - refresh_token: self.token, - grant_type: "refresh_token", - response_type: "token", - redirect_uri: "https://protonmail.ch/", - }) + fn build(&self) -> RequestData { + RequestData::new(http::Method::Post, "auth/v4/refresh").json(AuthRefresh { + uid: &self.uid.0, + refresh_token: self.token, + grant_type: "refresh_token", + response_type: "token", + redirect_uri: "https://protonmail.ch/", + }) } } pub struct LogoutRequest {} -impl http::Request for LogoutRequest { +impl http::RequestDesc for LogoutRequest { type Output = (); type Response = http::NoResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory.new_request(http::Method::Delete, "auth/v4") + fn build(&self) -> RequestData { + RequestData::new(http::Method::Delete, "auth/v4") } } @@ -292,16 +282,17 @@ impl<'a> CaptchaRequest<'a> { } } -impl<'a> http::Request for CaptchaRequest<'a> { +impl<'a> http::RequestDesc for CaptchaRequest<'a> { type Output = String; type Response = http::StringResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { + fn build(&self) -> RequestData { let url = if self.force_web { format!("core/v4/captcha?ForceWebMessaging=1&Token={}", self.token) } else { format!("core/v4/captcha?Token={}", self.token) }; - factory.new_request(http::Method::Get, &url) + + RequestData::new(http::Method::Get, url) } } diff --git a/src/requests/event.rs b/src/requests/event.rs index 7a61f6e..d281f8a 100644 --- a/src/requests/event.rs +++ b/src/requests/event.rs @@ -1,5 +1,5 @@ use crate::http; -use crate::http::{RequestData, RequestFactory}; +use crate::http::RequestData; use serde::Deserialize; #[doc(hidden)] @@ -11,12 +11,12 @@ pub struct LatestEventResponse { pub struct GetLatestEventRequest; -impl http::Request for GetLatestEventRequest { +impl http::RequestDesc for GetLatestEventRequest { type Output = LatestEventResponse; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory.new_request(http::Method::Get, "core/v4/events/latest") + fn build(&self) -> RequestData { + RequestData::new(http::Method::Get, "core/v4/events/latest") } } @@ -30,14 +30,14 @@ impl<'a> GetEventRequest<'a> { } } -impl<'a> http::Request for GetEventRequest<'a> { +impl<'a> http::RequestDesc for GetEventRequest<'a> { type Output = crate::domain::Event; type Response = http::JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> RequestData { - factory.new_request( + fn build(&self) -> RequestData { + RequestData::new( http::Method::Get, - &format!("core/v4/events/{}", self.event_id), + format!("core/v4/events/{}", self.event_id), ) } } diff --git a/src/requests/tests.rs b/src/requests/tests.rs index 031bae5..806a3ca 100644 --- a/src/requests/tests.rs +++ b/src/requests/tests.rs @@ -1,12 +1,13 @@ use crate::http; +use crate::http::RequestData; pub struct Ping; -impl http::Request for Ping { +impl http::RequestDesc for Ping { type Output = (); type Response = http::NoResponse; - fn build_request(&self, factory: &dyn http::RequestFactory) -> http::RequestData { - factory.new_request(http::Method::Get, "tests/ping") + fn build(&self) -> RequestData { + RequestData::new(http::Method::Get, "tests/ping") } } diff --git a/src/requests/user.rs b/src/requests/user.rs index fa22295..45a9101 100644 --- a/src/requests/user.rs +++ b/src/requests/user.rs @@ -1,6 +1,6 @@ use crate::domain::User; use crate::http; -use crate::http::{JsonResponse, RequestFactory}; +use crate::http::{JsonResponse, RequestData}; use serde::Deserialize; #[derive(Deserialize)] @@ -11,11 +11,11 @@ pub struct UserInfoResponse { pub struct UserInfoRequest {} -impl http::Request for UserInfoRequest { +impl http::RequestDesc for UserInfoRequest { type Output = UserInfoResponse; type Response = JsonResponse; - fn build_request(&self, factory: &dyn RequestFactory) -> http::RequestData { - factory.new_request(http::Method::Get, "core/v4/users") + fn build(&self) -> RequestData { + RequestData::new(http::Method::Get, "core/v4/users") } } diff --git a/tests/session/login.rs b/tests/session/login.rs index 5b37338..92adcce 100644 --- a/tests/session/login.rs +++ b/tests/session/login.rs @@ -1,38 +1,126 @@ -use crate::utils::create_session_and_server; +use crate::utils::{create_session_and_server, ClientASync, ClientSync}; +use proton_api_rs::domain::SecretString; +use proton_api_rs::http::Sequence; use proton_api_rs::{http, LoginError, Session, SessionType}; +use secrecy::{ExposeSecret, Secret}; +use tokio; const DEFAULT_USER_EMAIL: &str = "foo@bar.com"; const DEFAULT_USER_PASSWORD: &str = "12345"; #[test] fn session_login() { - let (client, server) = create_session_and_server(); + let (client, server) = create_session_and_server::(); + let (user_id, _) = server .create_user(DEFAULT_USER_EMAIL, DEFAULT_USER_PASSWORD) .expect("failed to create default user"); let auth_result = Session::login( - &client, DEFAULT_USER_EMAIL, - DEFAULT_USER_PASSWORD, + &Secret::::new(DEFAULT_USER_PASSWORD.to_string()), None, + ) + .do_sync(&client) + .expect("Failed to login"); + + assert!(matches!(auth_result, SessionType::Authenticated(_))); + + if let SessionType::Authenticated(s) = auth_result { + let user = s.get_user().do_sync(&client).expect("Failed to get user"); + assert_eq!(user.id.as_ref(), user_id.as_ref()); + + s.logout().do_sync(&client).expect("Failed to logout") + } +} + +#[test] +fn session_login_auto_refresh() { + let (client, server) = create_session_and_server::(); + + let (user_id, _) = server + .create_user(DEFAULT_USER_EMAIL, DEFAULT_USER_PASSWORD) + .expect("failed to create default user"); + let auth_result = Session::login( + DEFAULT_USER_EMAIL, + &Secret::::new(DEFAULT_USER_PASSWORD.to_string()), + None, + ) + .do_sync(&client) + .expect("Failed to login"); + + assert!(matches!(auth_result, SessionType::Authenticated(_))); + + if let SessionType::Authenticated(s) = auth_result { + let user = s.get_user().do_sync(&client).expect("Failed to get user"); + assert_eq!(user.id.as_ref(), user_id.as_ref()); + + let rs = s.get_refresh_data(); + server + .set_auth_timeout(std::time::Duration::from_secs(1)) + .expect("Failed to set timeout"); + std::thread::sleep(std::time::Duration::from_secs(1)); + + let user = s.get_user().do_sync(&client).expect("Failed to get user"); + assert_eq!(user.id.as_ref(), user_id.as_ref()); + + let rs_post_refresh = s.get_refresh_data(); + + assert_eq!( + rs.user_uid.expose_secret(), + rs_post_refresh.user_uid.expose_secret() + ); + + assert_ne!( + rs.token.expose_secret(), + rs_post_refresh.token.expose_secret() + ); + + s.logout().do_sync(&client).expect("Failed to logout") + } +} + +#[tokio::test()] +async fn session_login_async() { + let (client, server) = create_session_and_server::(); + + let (user_id, _) = server + .create_user(DEFAULT_USER_EMAIL, DEFAULT_USER_PASSWORD) + .expect("failed to create default user"); + let auth_result = Session::login( + DEFAULT_USER_EMAIL, + &Secret::::new(DEFAULT_USER_PASSWORD.to_string()), None, ) + .do_async(&client) + .await .expect("Failed to login"); assert!(matches!(auth_result, SessionType::Authenticated(_))); if let SessionType::Authenticated(s) = auth_result { - let user = s.get_user(&client).expect("Failed to get user"); + let user = s + .get_user() + .do_async(&client) + .await + .expect("Failed to get user"); assert_eq!(user.id.as_ref(), user_id.as_ref()); - s.logout(&client).expect("Failed to logout") + s.logout() + .do_async(&client) + .await + .expect("Failed to logout") } } #[test] fn session_login_invalid_user() { - let (client, _server) = create_session_and_server(); - let auth_result = Session::login(&client, "bar", DEFAULT_USER_PASSWORD, None, None); + let (client, _server) = create_session_and_server::(); + let auth_result = Session::login( + "bar", + &SecretString::new(DEFAULT_USER_PASSWORD.into()), + None, + ) + .do_sync(&client); assert!(matches!( auth_result, diff --git a/tests/session/utils.rs b/tests/session/utils.rs index 5263622..f9c521b 100644 --- a/tests/session/utils.rs +++ b/tests/session/utils.rs @@ -4,11 +4,13 @@ use proton_api_rs::http; use proton_api_rs::http::ClientBuilder; use std::sync::OnceLock; -type Client = http::ureq_client::UReqClient; +pub type ClientSync = http::ureq_client::UReqClient; +pub type ClientASync = http::reqwest_client::ReqwestClient; static LOG_CELL: OnceLock<()> = OnceLock::new(); -pub fn create_session_and_server() -> (Client, Server) { +pub fn create_session_and_server>( +) -> (Client, Server) { let debug = if let Ok(v) = std::env::var("RUST_LOG") { if v.eq_ignore_ascii_case("debug") { true