Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ description = "Helper to build an Eventsource using reqwest"
keywords = ["sse", "eventsource", "reqwest", "stream", "event"]
categories = ["web-programming::http-client", "no-std", "parsing", "asynchronous"]

[[example]]
name = "simple_digest"
path = "examples/simple_digest.rs"

[dependencies]
eventsource-stream = "0.2.3"
reqwest = { version = "0.12.0", default-features = false, features = ["stream"] }
Expand All @@ -22,10 +26,15 @@ nom = "7.1.0"
mime = "0.3.16"
futures-timer = "3.0.2"
thiserror = "1.0.30"
reqwest-middleware = "0.4.2"

[dev-dependencies]
futures = "0.3.5"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
futures-retry = "0.6"
pin-utils = "0.1"
rocket = "0.5.0"
digest_auth = "0.3.1"
async-trait = "0.1.89"
http = "1.3.1"

112 changes: 106 additions & 6 deletions examples/server.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,76 @@
use rocket::http::Status;
use rocket::request::Outcome;
use rocket::request::Request;
use digest_auth::{AuthContext, AuthorizationHeader, HttpMethod};
use rocket::http::{Header, Status};
use rocket::request::{FromRequest, Outcome, Request};
use rocket::response::content::RawHtml;
use rocket::response::stream::{Event, EventStream};
use rocket::response::Responder;
use rocket::tokio::time::{self, Duration};
use rocket::{get, launch, routes};
use rocket::{catch, catchers, get, launch, routes, Response};
use std::time::SystemTime;

const USER_NAME: &'static str = "admin";
const PASSWORD: &'static str = "admin";
pub struct DigestUser(pub String);

#[rocket::async_trait]
impl<'r> FromRequest<'r> for DigestUser {
type Error = ();

async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let header_val = match req.headers().get_one("Authorization") {
Some(h) => h,
None => return Outcome::Error((Status::Unauthorized, ())),
};

let auth_header = match AuthorizationHeader::parse(header_val) {
Ok(h) => h,
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
};

if auth_header.username != USER_NAME {
return Outcome::Error((Status::Unauthorized, ()));
}
let uri = auth_header.uri.clone();
let method = HttpMethod::from(req.method().as_str());

let ctx = AuthContext::new_with_method(
&auth_header.username,
PASSWORD,
&uri,
None::<&[u8]>,
method,
);

let mut computed = auth_header.clone();
computed.digest(&ctx);

if computed.response == auth_header.response {
Outcome::Success(DigestUser(auth_header.username))
} else {
Outcome::Error((Status::Unauthorized, ()))
}
}
}

struct DigestChallengeResponder;

impl<'r> Responder<'r, 'static> for DigestChallengeResponder {
fn respond_to(self, _: &'r Request<'_>) -> rocket::response::Result<'static> {
let header_value =
r#"Digest realm="Restricted", qop="auth", nonce=INSECURE_NONCE", opaque="opaque""#;

Response::build()
.status(Status::Unauthorized)
.header(Header::new("WWW-Authenticate", header_value))
.ok()
}
}

#[catch(401)]
fn unauthorized() -> DigestChallengeResponder {
DigestChallengeResponder
}

pub struct LastEventId(pub usize);

#[rocket::async_trait]
Expand Down Expand Up @@ -45,6 +109,24 @@ fn events(id: LastEventId) -> EventStream![] {
}
}

#[get("/events_auth")]
fn events_auth(id: LastEventId, user: DigestUser) -> EventStream![] {
let _ = user;
let mut id = id.0;
let mut interval = time::interval(Duration::from_secs(2));
EventStream! {
loop {
interval.tick().await;
let unix_time = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
id += 1;
yield Event::data(unix_time.to_string()).id(id.to_string());
}
}
}

#[get("/")]
fn index() -> RawHtml<&'static str> {
RawHtml(
Expand All @@ -56,7 +138,23 @@ Open Console
es.onmessage = (e) => console.log("Message:", e);
es.onerror = (e) => {
console.log("Error:", e);
// es.close();
};
</script>
"#,
)
}

#[get("/auth")]
fn index_auth() -> RawHtml<&'static str> {
RawHtml(
r#"
Open Console
<script>
const es = new EventSource("http://localhost:8000/events_auth");
es.onopen = () => console.log("Connection Open!");
es.onmessage = (e) => console.log("Message:", e);
es.onerror = (e) => {
console.log("Error:", e);
};
</script>
"#,
Expand All @@ -65,5 +163,7 @@ Open Console

#[launch]
fn rocket() -> _ {
rocket::build().mount("/", routes![events, index])
rocket::build()
.mount("/", routes![events, index, events_auth, index_auth])
.register("/", catchers![unauthorized])
}
85 changes: 85 additions & 0 deletions examples/simple_digest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use digest_auth::{AuthContext, HttpMethod, WwwAuthenticateHeader};
use futures::stream::StreamExt;
use http::Extensions;
use reqwest::{Request, Response};
use reqwest_eventsource::{CannotCloneRequestError, Event, RequestBuilderExt};
use reqwest_middleware::{Middleware, Next};

pub struct DigestAuth {
username: String,
password: String,
}

impl DigestAuth {
pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
Self {
username: username.into(),
password: password.into(),
}
}
}

#[async_trait::async_trait]
impl Middleware for DigestAuth {
async fn handle(
&self,
mut req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response, reqwest_middleware::Error> {
// Send first request
let res = next
.clone()
.run(
req.try_clone().ok_or_else(|| {
reqwest_middleware::Error::Middleware(CannotCloneRequestError.into())
})?,
extensions,
)
.await?;
if res.status() == reqwest::StatusCode::UNAUTHORIZED {
if let Some(response) = res
.headers()
.get("WWW-Authenticate")
.and_then(|header| header.to_str().ok())
.filter(|str| str.starts_with("Digest "))
.and_then(|header| WwwAuthenticateHeader::parse(header).ok())
.and_then(|mut challenge| {
challenge
.respond(&AuthContext::new_with_method(
&self.username,
&self.password,
&req.url().path().to_string(),
None::<&[u8]>,
HttpMethod::from(req.method().as_str()),
))
.ok()
})
.and_then(|response| response.to_header_string().parse().ok())
{
req.headers_mut().insert("Authorization", response);
return next.run(req, extensions).await;
}
}

Ok(res)
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
.with(DigestAuth::new("admin", "admin")).build();
let mut es = client.get("http://localhost:8000/events_auth").eventsource()?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => println!("Connection Open!"),
Ok(Event::Message(message)) => println!("Message: {:#?}", message),
Err(err) => {
println!("Error: {}", err);
// es.close();
}
}
}
Ok(())
}
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use core::fmt;
use eventsource_stream::EventStreamError;
use nom::error::Error as NomError;
use reqwest::header::HeaderValue;
use reqwest::Error as ReqwestError;
use reqwest::Response;
use reqwest::StatusCode;
use std::string::FromUtf8Error;
use reqwest_middleware::Error as ReqwestError;

#[cfg(doc)]
use reqwest::RequestBuilder;
Expand Down
58 changes: 55 additions & 3 deletions src/event_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ use futures_core::task::{Context, Poll};
use futures_timer::Delay;
use pin_project_lite::pin_project;
use reqwest::header::{HeaderName, HeaderValue};
use reqwest::{Error as ReqwestError, IntoUrl, RequestBuilder, Response, StatusCode};
use reqwest::{IntoUrl, RequestBuilder, Response, StatusCode};

use reqwest_middleware::{
ClientWithMiddleware, Error as ReqwestError, RequestBuilder as MiddlewareBuilder,
};
use std::time::Duration;

#[cfg(not(target_arch = "wasm32"))]
Expand Down Expand Up @@ -49,7 +53,7 @@ pin_project! {
/// [`RequestBuilder`] and retries requests when they fail.
#[project = EventSourceProjection]
pub struct EventSource {
builder: RequestBuilder,
builder: MiddlewareBuilder,
#[pin]
next_response: Option<ResponseFuture>,
#[pin]
Expand All @@ -66,6 +70,16 @@ pub struct EventSource {
impl EventSource {
/// Wrap a [`RequestBuilder`]
pub fn new(builder: RequestBuilder) -> Result<Self, CannotCloneRequestError> {
let (c, r) = builder.build_split();
Self::new_with_middleware(MiddlewareBuilder::from_parts(
ClientWithMiddleware::new(c, []),
r.map_err(|_| CannotCloneRequestError)?,
))
}
/// Wrap a [`MiddlewareBuilder`]
pub fn new_with_middleware(
builder: MiddlewareBuilder,
) -> Result<Self, CannotCloneRequestError> {
let builder = builder.header(
reqwest::header::ACCEPT,
HeaderValue::from_static("text/event-stream"),
Expand All @@ -88,6 +102,20 @@ impl EventSource {
Self::new(reqwest::Client::new().get(url)).unwrap()
}

pub fn get_with_middleware<
T: IntoUrl,
U: Into<Box<[std::sync::Arc<dyn reqwest_middleware::Middleware + 'static>]>>,
>(
url: T,
middleware: U,
) -> Self {
Self::new_with_middleware(
reqwest_middleware::ClientWithMiddleware::new(reqwest::Client::new(), middleware)
.get(url),
)
.unwrap()
}

/// Close the EventSource stream and stop trying to reconnect
pub fn close(&mut self) {
self.is_closed = true;
Expand Down Expand Up @@ -149,6 +177,30 @@ fn check_response(response: Response) -> Result<Response, Error> {
}
}

pin_project! {
struct MapErr<S>{
#[pin]
stream: S
}
}

impl<S> MapErr<S> {
fn new(stream: S) -> Self {
Self { stream }
}
}

impl<U, S: Stream<Item = Result<U, reqwest::Error>>> futures_core::Stream for MapErr<S> {
type Item = Result<U, ReqwestError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project()
.stream
.poll_next(cx)
.map_err(ReqwestError::Reqwest)
}
}

impl<'a> EventSourceProjection<'a> {
fn clear_fetch(&mut self) {
self.next_response.take();
Expand All @@ -169,7 +221,7 @@ impl<'a> EventSourceProjection<'a> {

fn handle_response(&mut self, res: Response) {
self.last_retry.take();
let mut stream = res.bytes_stream().eventsource();
let mut stream = MapErr::new(res.bytes_stream()).eventsource();
stream.set_last_event_id(self.last_event_id.clone());
self.cur_stream.replace(Box::pin(stream));
}
Expand Down
7 changes: 7 additions & 0 deletions src/reqwest_ext.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::error::CannotCloneRequestError;
use crate::event_source::EventSource;
use reqwest::RequestBuilder;
use reqwest_middleware::RequestBuilder as MiddlewareBuilder;

/// Provides an easy interface to build an [`EventSource`] from a [`RequestBuilder`]
pub trait RequestBuilderExt {
Expand All @@ -13,3 +14,9 @@ impl RequestBuilderExt for RequestBuilder {
EventSource::new(self)
}
}

impl RequestBuilderExt for MiddlewareBuilder {
fn eventsource(self) -> Result<EventSource, CannotCloneRequestError> {
EventSource::new_with_middleware(self)
}
}