diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dab2acd..220ecb0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,6 +60,15 @@ jobs: - name: Check async-std-runtime, async-tls, async-native-tls run: cargo check --features async-std-runtime,async-tls,async-native-tls + - name: Check smol-runtime, async-tls + run: cargo check --features smol-runtime,async-tls + + - name: Check smol-runtime, async-native-tls + run: cargo check --features smol-runtime,async-native-tls + + - name: Check smol-runtime, async-tls, async-native-tls + run: cargo check --features smol-runtime,async-tls,async-native-tls + - name: Check tokio-runtime, tokio-native-tls run: cargo check --features tokio-runtime,tokio-native-tls diff --git a/Cargo.toml b/Cargo.toml index 322e495..2bc2e36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ default = ["handshake", "futures-03-sink"] futures-03-sink = ["futures-util"] handshake = ["tungstenite/handshake"] async-std-runtime = ["async-std", "handshake"] +smol-runtime = ["async-net", "handshake"] tokio-runtime = ["tokio", "handshake"] gio-runtime = ["gio", "glib", "handshake"] async-tls = ["real-async-tls", "handshake"] @@ -34,7 +35,7 @@ url = ["tungstenite/url"] __rustls-tls = ["tokio-runtime", "real-tokio-rustls", "rustls-pki-types", "tungstenite/__rustls-tls"] [package.metadata.docs.rs] -features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "async-native-tls", "tokio-native-tls"] +features = ["async-std-runtime", "smol-runtime", "tokio-runtime", "gio-runtime", "async-tls", "async-native-tls", "tokio-native-tls"] [dependencies] log = "0.4" @@ -58,6 +59,10 @@ default-features = false optional = true version = "1.0" +[dependencies.async-net] +optional = true +version = "2.0" + [dependencies.real-tokio-openssl] optional = true version = "0.6" @@ -136,6 +141,11 @@ http-body-util = "0.1" version = "0.28" features = ["url"] +# For smol examples +[dependencies.smol] +optional = true +version = "2.0" + [[example]] name = "autobahn-client" required-features = ["async-std-runtime"] @@ -144,6 +154,10 @@ required-features = ["async-std-runtime"] name = "async-std-echo" required-features = ["async-std-runtime"] +[[example]] +name = "smol-echo" +required-features = ["smol-runtime", "smol"] + [[example]] name = "client" required-features = ["async-std-runtime"] diff --git a/examples/smol-echo.rs b/examples/smol-echo.rs new file mode 100644 index 0000000..86264d0 --- /dev/null +++ b/examples/smol-echo.rs @@ -0,0 +1,30 @@ +use async_tungstenite::{smol::connect_async, tungstenite::Message}; +use futures::prelude::*; + +async fn run() -> Result<(), Box> { + #[cfg(any(feature = "async-tls", feature = "async-native-tls"))] + let url = "wss://echo.websocket.org/.ws"; + #[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))] + let url = "ws://echo.websocket.org/.ws"; + + println!("Connecting: \"{}\"", url); + let (mut ws_stream, _) = connect_async(url).await?; + + let msg = ws_stream.next().await.ok_or("didn't receive anything")??; + println!("Received: {:?}", msg); + + let text = "Hello, World!"; + + println!("Sending: \"{}\"", text); + ws_stream.send(Message::text(text)).await?; + + let msg = ws_stream.next().await.ok_or("didn't receive anything")??; + + println!("Received: {:?}", msg); + + Ok(()) +} + +fn main() -> Result<(), Box> { + smol::block_on(run()) +} diff --git a/src/lib.rs b/src/lib.rs index b759357..5485fee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,6 +90,8 @@ pub mod async_std; pub mod async_tls; #[cfg(feature = "gio-runtime")] pub mod gio; +#[cfg(feature = "smol-runtime")] +pub mod smol; #[cfg(feature = "tokio-runtime")] pub mod tokio; @@ -747,6 +749,7 @@ impl Shared { #[cfg(any( feature = "async-tls", feature = "async-std-runtime", + feature = "smol-runtime", feature = "tokio-runtime", feature = "gio-runtime" ))] @@ -779,6 +782,7 @@ pub(crate) fn domain( #[cfg(any( feature = "async-std-runtime", + feature = "smol-runtime", feature = "tokio-runtime", feature = "gio-runtime" ))] @@ -805,6 +809,7 @@ mod tests { #[cfg(any( feature = "async-tls", feature = "async-std-runtime", + feature = "smol-runtime", feature = "tokio-runtime", feature = "gio-runtime" ))] diff --git a/src/smol.rs b/src/smol.rs new file mode 100644 index 0000000..985b9f5 --- /dev/null +++ b/src/smol.rs @@ -0,0 +1,299 @@ +//! `async-std` integration. +use tungstenite::client::IntoClientRequest; +use tungstenite::handshake::client::{Request, Response}; +use tungstenite::protocol::WebSocketConfig; +use tungstenite::Error; + +use async_net::TcpStream; + +use super::{domain, port, WebSocketStream}; + +#[cfg(feature = "async-native-tls")] +use futures_io::{AsyncRead, AsyncWrite}; + +#[cfg(feature = "async-native-tls")] +pub(crate) mod async_native_tls { + use async_native_tls::TlsConnector as AsyncTlsConnector; + use async_native_tls::TlsStream; + use real_async_native_tls as async_native_tls; + + use tungstenite::client::uri_mode; + use tungstenite::handshake::client::Request; + use tungstenite::stream::Mode; + use tungstenite::Error; + + use futures_io::{AsyncRead, AsyncWrite}; + + use crate::stream::Stream as StreamSwitcher; + use crate::{ + client_async_with_config, domain, IntoClientRequest, Response, WebSocketConfig, + WebSocketStream, + }; + + /// A stream that might be protected with TLS. + pub type MaybeTlsStream = StreamSwitcher>; + + pub type AutoStream = MaybeTlsStream; + + pub type Connector = AsyncTlsConnector; + + async fn wrap_stream( + socket: S, + domain: String, + connector: Option, + mode: Mode, + ) -> Result, Error> + where + S: 'static + AsyncRead + AsyncWrite + Unpin, + { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(socket)), + Mode::Tls => { + let stream = { + let connector = if let Some(connector) = connector { + connector + } else { + AsyncTlsConnector::new() + }; + connector + .connect(&domain, socket) + .await + .map_err(|err| Error::Tls(err.into()))? + }; + Ok(StreamSwitcher::Tls(stream)) + } + } + } + + /// Creates a WebSocket handshake from a request and a stream, + /// upgrading the stream to TLS if required and using the given + /// connector and WebSocket configuration. + pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: IntoClientRequest + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, + { + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = uri_mode(request.uri())?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await + } +} + +#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))] +pub(crate) mod dummy_tls { + use futures_io::{AsyncRead, AsyncWrite}; + + use tungstenite::client::{uri_mode, IntoClientRequest}; + use tungstenite::handshake::client::Request; + use tungstenite::stream::Mode; + use tungstenite::Error; + + use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream}; + + pub type AutoStream = S; + type Connector = (); + + async fn wrap_stream( + socket: S, + _domain: String, + _connector: Option<()>, + mode: Mode, + ) -> Result, Error> + where + S: 'static + AsyncRead + AsyncWrite + Unpin, + { + match mode { + Mode::Plain => Ok(socket), + Mode::Tls => Err(Error::Url( + tungstenite::error::UrlError::TlsFeatureNotEnabled, + )), + } + } + + /// Creates a WebSocket handshake from a request and a stream, + /// upgrading the stream to TLS if required and using the given + /// connector and WebSocket configuration. + pub async fn client_async_tls_with_connector_and_config( + request: R, + stream: S, + connector: Option, + config: Option, + ) -> Result<(WebSocketStream>, Response), Error> + where + R: IntoClientRequest + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, + { + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + + // Make sure we check domain and mode first. URL must be valid. + let mode = uri_mode(request.uri())?; + + let stream = wrap_stream(stream, domain, connector, mode).await?; + client_async_with_config(request, stream, config).await + } +} + +#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))] +pub use self::dummy_tls::client_async_tls_with_connector_and_config; +#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))] +use self::dummy_tls::AutoStream; + +#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))] +pub use crate::async_tls::client_async_tls_with_connector_and_config; +#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))] +use crate::async_tls::AutoStream; +#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))] +type Connector = real_async_tls::TlsConnector; + +#[cfg(feature = "async-native-tls")] +pub use self::async_native_tls::client_async_tls_with_connector_and_config; +#[cfg(feature = "async-native-tls")] +use self::async_native_tls::{AutoStream, Connector}; + +/// Type alias for the stream type of the `client_async()` functions. +pub type ClientStream = AutoStream; + +#[cfg(feature = "async-native-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required. +pub async fn client_async_tls( + request: R, + stream: S, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, None).await +} + +#[cfg(feature = "async-native-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// WebSocket configuration. +pub async fn client_async_tls_with_config( + request: R, + stream: S, + config: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, None, config).await +} + +#[cfg(feature = "async-native-tls")] +/// Creates a WebSocket handshake from a request and a stream, +/// upgrading the stream to TLS if required and using the given +/// connector. +pub async fn client_async_tls_with_connector( + request: R, + stream: S, + connector: Option, +) -> Result<(WebSocketStream>, Response), Error> +where + R: IntoClientRequest + Unpin, + S: 'static + AsyncRead + AsyncWrite + Unpin, + AutoStream: Unpin, +{ + client_async_tls_with_connector_and_config(request, stream, connector, None).await +} + +/// Type alias for the stream type of the `connect_async()` functions. +pub type ConnectStream = ClientStream; + +/// Connect to a given URL. +/// +/// Accepts any request that implements [`IntoClientRequest`], which is often just `&str`, but can +/// be a variety of types such as `httparse::Request` or [`tungstenite::http::Request`] for more +/// complex uses. +/// +/// ```no_run +/// # use tungstenite::client::IntoClientRequest; +/// +/// # async fn test() { +/// use tungstenite::http::{Method, Request}; +/// use async_tungstenite::smol::connect_async; +/// +/// let mut request = "wss://api.example.com".into_client_request().unwrap(); +/// request.headers_mut().insert("api-key", "42".parse().unwrap()); +/// +/// let (stream, response) = connect_async(request).await.unwrap(); +/// # } +/// ``` +pub async fn connect_async( + request: R, +) -> Result<(WebSocketStream, Response), Error> +where + R: IntoClientRequest + Unpin, +{ + connect_async_with_config(request, None).await +} + +/// Connect to a given URL with a given WebSocket configuration. +pub async fn connect_async_with_config( + request: R, + config: Option, +) -> Result<(WebSocketStream, Response), Error> +where + R: IntoClientRequest + Unpin, +{ + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + let port = port(&request)?; + + let try_socket = TcpStream::connect((domain.as_str(), port)).await; + let socket = try_socket.map_err(Error::Io)?; + client_async_tls_with_connector_and_config(request, socket, None, config).await +} + +#[cfg(any(feature = "async-tls", feature = "async-native-tls"))] +/// Connect to a given URL using the provided TLS connector. +pub async fn connect_async_with_tls_connector( + request: R, + connector: Option, +) -> Result<(WebSocketStream, Response), Error> +where + R: IntoClientRequest + Unpin, +{ + connect_async_with_tls_connector_and_config(request, connector, None).await +} + +#[cfg(any(feature = "async-tls", feature = "async-native-tls"))] +/// Connect to a given URL using the provided TLS connector. +pub async fn connect_async_with_tls_connector_and_config( + request: R, + connector: Option, + config: Option, +) -> Result<(WebSocketStream, Response), Error> +where + R: IntoClientRequest + Unpin, +{ + let request: Request = request.into_client_request()?; + + let domain = domain(&request)?; + let port = port(&request)?; + + let try_socket = TcpStream::connect((domain.as_str(), port)).await; + let socket = try_socket.map_err(Error::Io)?; + client_async_tls_with_connector_and_config(request, socket, connector, config).await +}