diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 10552695edf..532d374595f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ env: RUST_BACKTRACE: 1 RUSTFLAGS: -Dwarnings RUSTDOCFLAGS: -Dwarnings - MSRV: "1.81" + MSRV: "1.85" SCCACHE_CACHE_SIZE: "10G" IROH_FORCE_STAGING_RELAYS: "1" @@ -256,7 +256,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-30 + toolchain: nightly-2025-02-20 - name: Install sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 4aac437cf9c..69fda04125d 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -32,7 +32,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: nightly-2024-11-30 + toolchain: nightly-2025-02-20 - name: Install sccache uses: mozilla-actions/sccache-action@v0.0.9 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2a90264184d..1f701f0642c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -40,7 +40,7 @@ env: RUST_BACKTRACE: 1 RUSTFLAGS: -Dwarnings RUSTDOCFLAGS: -Dwarnings - MSRV: "1.81" + MSRV: "1.85" SCCACHE_CACHE_SIZE: "10G" BIN_NAMES: "iroh-relay,iroh-dns-server" RELEASE_VERSION: ${{ github.event.inputs.release_version }} diff --git a/.github/workflows/test_relay_server.yml b/.github/workflows/test_relay_server.yml index ad6afedf67f..37cc252226f 100644 --- a/.github/workflows/test_relay_server.yml +++ b/.github/workflows/test_relay_server.yml @@ -13,7 +13,7 @@ env: RUST_BACKTRACE: 1 RUSTFLAGS: -Dwarnings RUSTDOCFLAGS: -Dwarnings - MSRV: "1.81" + MSRV: "1.85" SCCACHE_CACHE_SIZE: "10G" IROH_FORCE_STAGING_RELAYS: "1" diff --git a/Cargo.lock b/Cargo.lock index 94a9bc5f737..cebcdee701d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. 
# It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "acto" @@ -146,6 +146,9 @@ name = "anyhow" version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +dependencies = [ + "backtrace", +] [[package]] name = "anymap2" @@ -463,6 +466,12 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "102dbef1187b1893e6dfe05a774e79fd52265f49f214f6879c8ff49f52c8188b" +[[package]] +name = "btparse" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "387e80962b798815a2b5c4bcfdb6bf626fa922ffe9f74e373103b858738e9f31" + [[package]] name = "bumpalo" version = "3.17.0" @@ -622,6 +631,17 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15" +[[package]] +name = "color-backtrace" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2123a5984bd52ca861c66f66a9ab9883b27115c607f801f86c1bc2a84eb69f0f" +dependencies = [ + "backtrace", + "btparse", + "termcolor", +] + [[package]] name = "colorchoice" version = "1.0.3" @@ -2237,6 +2257,8 @@ dependencies = [ "iroh-quinn-udp", "iroh-relay", "n0-future", + "n0-snafu", + "nested_enum_utils", "netdev", "netwatch", "parse-size", @@ -2256,13 +2278,12 @@ dependencies = [ "serde", "serde_json", "smallvec", + "snafu", "spki", "strum", "stun-rs", "surge-ping", "swarm-discovery", - "testresult", - "thiserror 2.0.12", "time", "tokio", "tokio-stream", @@ -2287,6 +2308,8 @@ dependencies = [ "data-encoding", "derive_more", "ed25519-dalek", + "n0-snafu", + "nested_enum_utils", "postcard", "proptest", "rand 0.8.5", @@ -2294,7 +2317,7 @@ dependencies = [ "serde", "serde_json", "serde_test", - "thiserror 2.0.12", + "snafu", "url", ] @@ -2302,7 +2325,6 @@ dependencies = [ name = 
"iroh-bench" version = "0.35.0" dependencies = [ - "anyhow", "bytes", "clap", "hdrhistogram", @@ -2310,6 +2332,7 @@ dependencies = [ "iroh-metrics", "iroh-quinn", "n0-future", + "n0-snafu", "rand 0.8.5", "rcgen", "rustls", @@ -2322,7 +2345,6 @@ dependencies = [ name = "iroh-dns-server" version = "0.35.0" dependencies = [ - "anyhow", "async-trait", "axum", "axum-server", @@ -2342,6 +2364,7 @@ dependencies = [ "iroh-metrics", "lru", "n0-future", + "n0-snafu", "pkarr", "rand 0.8.5", "rand_chacha 0.3.1", @@ -2351,9 +2374,9 @@ dependencies = [ "rustls", "rustls-pemfile", "serde", + "snafu", "struct_iterable", "strum", - "testresult", "tokio", "tokio-rustls", "tokio-rustls-acme", @@ -2460,7 +2483,6 @@ name = "iroh-relay" version = "0.35.0" dependencies = [ "ahash", - "anyhow", "bytes", "cfg_aliases", "clap", @@ -2482,6 +2504,8 @@ dependencies = [ "iroh-quinn-proto", "lru", "n0-future", + "n0-snafu", + "nested_enum_utils", "num_enum", "pin-project", "pkarr", @@ -2503,10 +2527,9 @@ dependencies = [ "serde_json", "sha1", "simdutf8", + "snafu", "strum", "stun-rs", - "testresult", - "thiserror 2.0.12", "time", "tokio", "tokio-rustls", @@ -2802,6 +2825,19 @@ dependencies = [ "web-time", ] +[[package]] +name = "n0-snafu" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6726cb23c307232453032bca1c4fba5a5c60dc36775211e134ab304ab510b548" +dependencies = [ + "anyhow", + "btparse", + "color-backtrace", + "snafu", + "tracing-error", +] + [[package]] name = "nested_enum_utils" version = "0.2.2" @@ -4421,6 +4457,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" dependencies = [ + "backtrace", "snafu-derive", ] @@ -4681,10 +4718,13 @@ dependencies = [ ] [[package]] -name = "testresult" -version = "0.4.1" +name = "termcolor" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"614b328ff036a4ef882c61570f72918f7e9c5bee1da33f8e7f91e01daee7e56c" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] [[package]] name = "thiserror" @@ -5056,6 +5096,16 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-error" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b1581020d7a273442f5b45074a6a57d5757ad0a47dac0e9f0bd57b81936f3db" +dependencies = [ + "tracing", + "tracing-subscriber", +] + [[package]] name = "tracing-log" version = "0.2.0" diff --git a/iroh-base/Cargo.toml b/iroh-base/Cargo.toml index 868061bc62a..b5db139b73f 100644 --- a/iroh-base/Cargo.toml +++ b/iroh-base/Cargo.toml @@ -9,7 +9,7 @@ authors = ["n0 team"] repository = "https://github.com/n0-computer/iroh" # Sadly this also needs to be updated in .github/workflows/ci.yml -rust-version = "1.81" +rust-version = "1.85" [lints] workspace = true @@ -23,7 +23,9 @@ url = { version = "2.5.3", features = ["serde"], optional = true } postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"], optional = true } rand_core = { version = "0.6.4", optional = true } serde = { version = "1", features = ["derive", "rc"] } -thiserror = { version = "2", optional = true } +snafu = { version = "0.8.5", features = ["rust_1_81"], optional = true } +n0-snafu = "0.2.0" +nested_enum_utils = "0.2.0" [dev-dependencies] postcard = { version = "1", features = ["use-std"] } @@ -41,7 +43,7 @@ key = [ "dep:ed25519-dalek", "dep:url", "dep:derive_more", - "dep:thiserror", + "dep:snafu", "dep:data-encoding", "dep:rand_core", "relay", @@ -49,7 +51,7 @@ key = [ relay = [ "dep:url", "dep:derive_more", - "dep:thiserror", + "dep:snafu", ] [package.metadata.docs.rs] diff --git a/iroh-base/src/key.rs b/iroh-base/src/key.rs index 5f4a6511b1d..4e99c66cf28 100644 --- a/iroh-base/src/key.rs +++ b/iroh-base/src/key.rs @@ -9,10 +9,12 @@ use std::{ }; use 
curve25519_dalek::edwards::CompressedEdwardsY; -pub use ed25519_dalek::Signature; -use ed25519_dalek::{SignatureError, SigningKey, VerifyingKey}; +pub use ed25519_dalek::{Signature, SignatureError}; +use ed25519_dalek::{SigningKey, VerifyingKey}; +use nested_enum_utils::common_fields; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; +use snafu::{Backtrace, Snafu}; /// A public key. /// @@ -180,17 +182,26 @@ impl Display for PublicKey { } /// Error when deserialising a [`PublicKey`] or a [`SecretKey`]. -#[derive(thiserror::Error, Debug)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Snafu, Debug)] +#[allow(missing_docs)] +#[snafu(visibility(pub(crate)))] pub enum KeyParsingError { /// Error when decoding. - #[error("decoding: {0}")] - Decode(#[from] data_encoding::DecodeError), + #[snafu(transparent)] + Decode { source: data_encoding::DecodeError }, /// Error when decoding the public key. - #[error("key: {0}")] - Key(#[from] ed25519_dalek::SignatureError), + #[snafu(transparent)] + Key { + source: ed25519_dalek::SignatureError, + }, /// The encoded information had the wrong length. - #[error("invalid length")] - DecodeInvalidLength, + #[snafu(display("invalid length"))] + DecodeInvalidLength {}, } /// Deserialises the [`PublicKey`] from it's base32 encoding. @@ -333,14 +344,14 @@ fn decode_base32_hex(s: &str) -> Result<[u8; 32], KeyParsingError> { let input = s.to_ascii_uppercase(); let input = input.as_bytes(); if data_encoding::BASE32_NOPAD.decode_len(input.len())? 
!= bytes.len() { - return Err(KeyParsingError::DecodeInvalidLength); + return Err(DecodeInvalidLengthSnafu.build()); } data_encoding::BASE32_NOPAD.decode_mut(input, &mut bytes) }; match res { Ok(len) => { if len != PublicKey::LENGTH { - return Err(KeyParsingError::DecodeInvalidLength); + return Err(DecodeInvalidLengthSnafu.build()); } } Err(partial) => return Err(partial.error.into()), diff --git a/iroh-base/src/lib.rs b/iroh-base/src/lib.rs index 2b32e2719b4..af3be651a5d 100644 --- a/iroh-base/src/lib.rs +++ b/iroh-base/src/lib.rs @@ -15,7 +15,7 @@ mod node_addr; mod relay_url; #[cfg(feature = "key")] -pub use self::key::{KeyParsingError, NodeId, PublicKey, SecretKey, Signature}; +pub use self::key::{KeyParsingError, NodeId, PublicKey, SecretKey, Signature, SignatureError}; #[cfg(feature = "key")] pub use self::node_addr::NodeAddr; #[cfg(feature = "relay")] diff --git a/iroh-base/src/relay_url.rs b/iroh-base/src/relay_url.rs index 6f6b4cea975..b1aef365282 100644 --- a/iroh-base/src/relay_url.rs +++ b/iroh-base/src/relay_url.rs @@ -1,6 +1,7 @@ use std::{fmt, ops::Deref, str::FromStr, sync::Arc}; use serde::{Deserialize, Serialize}; +use snafu::{Backtrace, ResultExt, Snafu}; use url::Url; /// A URL identifying a relay server. @@ -38,9 +39,12 @@ impl From for RelayUrl { } /// Can occur when parsing a string into a [`RelayUrl`]. -#[derive(Debug, thiserror::Error)] -#[error("Failed to parse: {0}")] -pub struct RelayUrlParseError(#[from] url::ParseError); +#[derive(Debug, Snafu)] +#[snafu(display("Failed to parse"))] +pub struct RelayUrlParseError { + source: url::ParseError, + backtrace: Option, +} /// Support for parsing strings directly. 
/// @@ -50,7 +54,7 @@ impl FromStr for RelayUrl { type Err = RelayUrlParseError; fn from_str(s: &str) -> Result { - let inner = Url::from_str(s)?; + let inner = Url::from_str(s).context(RelayUrlParseSnafu)?; Ok(RelayUrl::from(inner)) } } diff --git a/iroh-base/src/ticket.rs b/iroh-base/src/ticket.rs index 0eea9788f74..82c5d3b2658 100644 --- a/iroh-base/src/ticket.rs +++ b/iroh-base/src/ticket.rs @@ -5,7 +5,9 @@ use std::{collections::BTreeSet, net::SocketAddr}; +use nested_enum_utils::common_fields; use serde::{Deserialize, Serialize}; +use snafu::{Backtrace, Snafu}; use crate::{key::NodeId, relay_url::RelayUrl}; @@ -35,7 +37,7 @@ pub trait Ticket: Sized { fn to_bytes(&self) -> Vec; /// Deserialize from the base32 string representation bytes. - fn from_bytes(bytes: &[u8]) -> Result; + fn from_bytes(bytes: &[u8]) -> Result; /// Serialize to string. fn serialize(&self) -> String { @@ -45,10 +47,10 @@ pub trait Ticket: Sized { } /// Deserialize from a string. - fn deserialize(str: &str) -> Result { + fn deserialize(str: &str) -> Result { let expected = Self::KIND; let Some(rest) = str.strip_prefix(expected) else { - return Err(Error::Kind { expected }); + return Err(KindSnafu { expected }.build()); }; let bytes = data_encoding::BASE32_NOPAD.decode(rest.to_ascii_uppercase().as_bytes())?; let ticket = Self::from_bytes(&bytes)?; @@ -57,23 +59,31 @@ pub trait Ticket: Sized { } /// An error deserializing an iroh ticket. -#[derive(Debug, thiserror::Error)] -pub enum Error { +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +#[snafu(visibility(pub(crate)))] +#[non_exhaustive] +pub enum ParseError { /// Found a ticket of with the wrong prefix, indicating the wrong kind. - #[error("wrong prefix, expected {expected}")] + #[snafu(display("wrong prefix, expected {expected}"))] Kind { /// The expected prefix. 
expected: &'static str, }, /// This looks like a ticket, but postcard deserialization failed. - #[error("deserialization failed: {_0}")] - Postcard(#[from] postcard::Error), + #[snafu(display("deserialization failed"))] + Postcard { source: postcard::Error }, /// This looks like a ticket, but base32 decoding failed. - #[error("decoding failed: {_0}")] - Encoding(#[from] data_encoding::DecodeError), + #[snafu(transparent)] + Encoding { source: data_encoding::DecodeError }, /// Verification of the deserialized bytes failed. - #[error("verification failed: {_0}")] - Verify(&'static str), + #[snafu(display("verification failed: {message}"))] + Verify { message: &'static str }, } #[derive(Serialize, Deserialize)] diff --git a/iroh-base/src/ticket/node.rs b/iroh-base/src/ticket/node.rs index e964def5dd8..8dd8c931c86 100644 --- a/iroh-base/src/ticket/node.rs +++ b/iroh-base/src/ticket/node.rs @@ -3,11 +3,12 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; +use snafu::ResultExt; use super::{Variant0AddrInfo, Variant0NodeAddr}; use crate::{ node_addr::NodeAddr, - ticket::{self, Ticket}, + ticket::{self, ParseError, Ticket}, }; /// A token containing information for establishing a connection to a node. 
@@ -61,8 +62,8 @@ impl Ticket for NodeTicket { postcard::to_stdvec(&data).expect("postcard serialization failed") } - fn from_bytes(bytes: &[u8]) -> Result { - let res: TicketWireFormat = postcard::from_bytes(bytes).map_err(ticket::Error::Postcard)?; + fn from_bytes(bytes: &[u8]) -> Result { + let res: TicketWireFormat = postcard::from_bytes(bytes).context(ticket::PostcardSnafu)?; let TicketWireFormat::Variant0(Variant0NodeTicket { node }) = res; Ok(Self { node: NodeAddr { @@ -75,7 +76,7 @@ impl Ticket for NodeTicket { } impl FromStr for NodeTicket { - type Err = ticket::Error; + type Err = ParseError; fn from_str(s: &str) -> Result { ticket::Ticket::deserialize(s) diff --git a/iroh-dns-server/Cargo.toml b/iroh-dns-server/Cargo.toml index c0c85f6c95f..af4984c415b 100644 --- a/iroh-dns-server/Cargo.toml +++ b/iroh-dns-server/Cargo.toml @@ -10,7 +10,6 @@ keywords = ["networking", "pkarr", "dns", "dns-server", "iroh"] readme = "README.md" [dependencies] -anyhow = "1.0.80" async-trait = "0.1.77" axum = { version = "0.8", features = ["macros"] } axum-server = { version = "0.7", features = ["tls-rustls-no-provider"] } @@ -32,14 +31,16 @@ humantime-serde = "1.1.1" iroh-metrics = { version = "0.34", features = ["service"] } lru = "0.13" n0-future = "0.1.2" +n0-snafu = "0.2.0" pkarr = { version = "3.7", features = ["relays", "dht"], default-features = false } rcgen = "0.13" -redb = "2.0.0" +redb = "=2.4.0" # 2.5.0 has MSRV 1.85 regex = "1.10.3" rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-pemfile = { version = "2.1" } serde = { version = "1", features = ["derive"] } struct_iterable = "0.1.1" +snafu = { version = "0.8.5", features = ["rust_1_81"] } strum = { version = "0.26", features = ["derive"] } tokio = { version = "1", features = ["full"] } tokio-rustls = { version = "0.26", default-features = false, features = [ @@ -64,7 +65,6 @@ hickory-resolver = "0.25.0" iroh = { path = "../iroh" } rand = "0.8" rand_chacha = "0.3.1" -testresult 
= "0.4.1" tracing-test = "0.2.5" [[bench]] diff --git a/iroh-dns-server/benches/write.rs b/iroh-dns-server/benches/write.rs index 771b673c0f3..9522a6980e5 100644 --- a/iroh-dns-server/benches/write.rs +++ b/iroh-dns-server/benches/write.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use anyhow::Result; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use iroh::{discovery::pkarr::PkarrRelayClient, node_info::NodeInfo, SecretKey}; use iroh_dns_server::{config::Config, metrics::Metrics, server::Server, ZoneStore}; +use n0_snafu::Result; use rand_chacha::rand_core::SeedableRng; use tokio::runtime::Runtime; diff --git a/iroh-dns-server/examples/convert.rs b/iroh-dns-server/examples/convert.rs index 79173d666bf..c8d4a755f90 100644 --- a/iroh-dns-server/examples/convert.rs +++ b/iroh-dns-server/examples/convert.rs @@ -2,6 +2,7 @@ use std::str::FromStr; use clap::Parser; use iroh::NodeId; +use n0_snafu::{Result, ResultExt}; #[derive(Debug, Parser)] struct Cli { @@ -15,17 +16,17 @@ enum Command { PkarrToNode { z32_pubkey: String }, } -fn main() -> anyhow::Result<()> { +fn main() -> Result<()> { let args = Cli::parse(); match args.command { Command::NodeToPkarr { node_id } => { let node_id = NodeId::from_str(&node_id)?; - let public_key = pkarr::PublicKey::try_from(node_id.as_bytes())?; + let public_key = pkarr::PublicKey::try_from(node_id.as_bytes()).e()?; println!("{}", public_key.to_z32()) } Command::PkarrToNode { z32_pubkey } => { - let public_key = pkarr::PublicKey::try_from(z32_pubkey.as_str())?; - let node_id = NodeId::from_bytes(public_key.as_bytes())?; + let public_key = pkarr::PublicKey::try_from(z32_pubkey.as_str()).e()?; + let node_id = NodeId::from_bytes(public_key.as_bytes()).e()?; println!("{}", node_id) } } diff --git a/iroh-dns-server/examples/publish.rs b/iroh-dns-server/examples/publish.rs index 872f119e01b..11c9bae5032 100644 --- a/iroh-dns-server/examples/publish.rs +++ b/iroh-dns-server/examples/publish.rs @@ -1,6 +1,5 @@ 
use std::{net::SocketAddr, str::FromStr}; -use anyhow::{Context, Result}; use clap::{Parser, ValueEnum}; use iroh::{ discovery::{ @@ -11,6 +10,7 @@ use iroh::{ node_info::{NodeIdExt, NodeInfo, IROH_TXT_NAME}, NodeId, SecretKey, }; +use n0_snafu::{Result, ResultExt}; use url::Url; const DEV_PKARR_RELAY_URL: &str = "http://localhost:8080/pkarr"; diff --git a/iroh-dns-server/examples/resolve.rs b/iroh-dns-server/examples/resolve.rs index 11299405d85..e1e9e25e813 100644 --- a/iroh-dns-server/examples/resolve.rs +++ b/iroh-dns-server/examples/resolve.rs @@ -1,10 +1,10 @@ -use anyhow::Context; use clap::{Parser, ValueEnum}; use iroh::{ discovery::dns::{N0_DNS_NODE_ORIGIN_PROD, N0_DNS_NODE_ORIGIN_STAGING}, dns::DnsResolver, NodeId, }; +use n0_snafu::{Result, ResultExt}; const DEV_DNS_SERVER: &str = "127.0.0.1:5300"; const DEV_DNS_ORIGIN_DOMAIN: &str = "irohdns.example"; @@ -44,11 +44,12 @@ enum Command { } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<()> { let args = Cli::parse(); let resolver = if let Some(host) = args.dns_server { let addr = tokio::net::lookup_host(host) - .await? + .await + .e()? 
.next() .context("failed to resolve DNS server address")?; DnsResolver::with_nameserver(addr) diff --git a/iroh-dns-server/src/config.rs b/iroh-dns-server/src/config.rs index ba8409d24a0..867d89e07d5 100644 --- a/iroh-dns-server/src/config.rs +++ b/iroh-dns-server/src/config.rs @@ -7,7 +7,7 @@ use std::{ time::Duration, }; -use anyhow::{anyhow, Context, Result}; +use n0_snafu::{Result, ResultExt}; use serde::{Deserialize, Serialize}; use tracing::info; @@ -163,7 +163,7 @@ impl Config { let s = tokio::fs::read_to_string(path.as_ref()) .await .with_context(|| format!("failed to read {}", path.as_ref().to_string_lossy()))?; - let config: Config = toml::from_str(&s)?; + let config: Config = toml::from_str(&s).e()?; Ok(config) } @@ -172,9 +172,9 @@ impl Config { let dir = if let Some(val) = env::var_os("IROH_DNS_DATA_DIR") { PathBuf::from(val) } else { - let path = dirs_next::data_dir().ok_or_else(|| { - anyhow!("operating environment provides no directory for application data") - })?; + let path = dirs_next::data_dir() + .context("operating environment provides no directory for application data")?; + path.join("iroh-dns") }; Ok(dir) diff --git a/iroh-dns-server/src/dns.rs b/iroh-dns-server/src/dns.rs index eb8d27ea17c..679a46d3ae0 100644 --- a/iroh-dns-server/src/dns.rs +++ b/iroh-dns-server/src/dns.rs @@ -8,7 +8,6 @@ use std::{ time::Duration, }; -use anyhow::{anyhow, Result}; use async_trait::async_trait; use bytes::Bytes; use hickory_server::{ @@ -26,6 +25,7 @@ use hickory_server::{ server::{Request, RequestHandler, ResponseHandler, ResponseInfo}, store::in_memory::InMemoryAuthority, }; +use n0_snafu::{format_err, Result, ResultExt}; use serde::{Deserialize, Serialize}; use tokio::{ net::{TcpListener, UdpSocket}, @@ -82,12 +82,12 @@ impl DnsServer { config.port, ); - let socket = UdpSocket::bind(bind_addr).await?; + let socket = UdpSocket::bind(bind_addr).await.e()?; - let socket_addr = socket.local_addr()?; + let socket_addr = socket.local_addr().e()?; 
server.register_socket(socket); - server.register_listener(TcpListener::bind(bind_addr).await?, TCP_TIMEOUT); + server.register_listener(TcpListener::bind(bind_addr).await.e()?, TCP_TIMEOUT); info!("DNS server listening on {}", bind_addr); Ok(Self { @@ -103,7 +103,7 @@ impl DnsServer { /// Shutdown the server an wait for all tasks to complete. pub async fn shutdown(mut self) -> Result<()> { - self.server.shutdown_gracefully().await?; + self.server.shutdown_gracefully().await.e()?; Ok(()) } @@ -111,7 +111,7 @@ impl DnsServer { /// /// Runs forever unless tasks fail. pub async fn run_until_done(mut self) -> Result<()> { - self.server.block_until_done().await?; + self.server.block_until_done().await.e()?; Ok(()) } } @@ -132,7 +132,8 @@ impl DnsHandler { .origins .iter() .map(Name::from_utf8) - .collect::, _>>()?; + .collect::, _>>() + .e()?; let (static_authority, serial) = create_static_authority(&origins, config)?; let authority = Arc::new(NodeAuthority::new( @@ -158,7 +159,7 @@ impl DnsHandler { let (tx, mut rx) = broadcast::channel(1); let response_handle = Handle(tx); self.handle_request(&request, response_handle).await; - Ok(rx.recv().await?) + rx.recv().await.e() } } @@ -232,9 +233,10 @@ fn create_static_authority( RecordType::SOA, config.default_soa.split_ascii_whitespace(), None, - )? + ) + .e()? 
.into_soa() - .map_err(|_| anyhow!("Couldn't parse SOA: {}", config.default_soa))?; + .map_err(|_| format_err!("Couldn't parse SOA: {}", config.default_soa))?; let serial = soa.serial(); let mut records = BTreeMap::new(); for name in origins { @@ -258,7 +260,7 @@ fn create_static_authority( ); } if let Some(ns) = &config.rr_ns { - let ns = Name::parse(ns, Some(&Name::root()))?; + let ns = Name::parse(ns, Some(&Name::root())).e()?; push_record( &mut records, serial, @@ -268,7 +270,7 @@ fn create_static_authority( } let static_authority = InMemoryAuthority::new(Name::root(), records, ZoneType::Primary, false) - .map_err(|e| anyhow!(e))?; + .map_err(|e| format_err!("new authority: {}", e))?; Ok((static_authority, serial)) } diff --git a/iroh-dns-server/src/dns/node_authority.rs b/iroh-dns-server/src/dns/node_authority.rs index 3f59fd4fe59..bf217a5eae0 100644 --- a/iroh-dns-server/src/dns/node_authority.rs +++ b/iroh-dns-server/src/dns/node_authority.rs @@ -1,6 +1,5 @@ use std::{fmt, sync::Arc}; -use anyhow::{bail, ensure, Context, Result}; use async_trait::async_trait; use hickory_server::{ authority::{ @@ -14,6 +13,8 @@ use hickory_server::{ server::RequestInfo, store::in_memory::InMemoryAuthority, }; +use n0_snafu::{Result, ResultExt}; +use snafu::whatever; use tracing::{debug, trace}; use crate::{ @@ -40,7 +41,9 @@ impl NodeAuthority { origins: Vec, serial: u32, ) -> Result { - ensure!(!origins.is_empty(), "at least one origin is required"); + if origins.is_empty() { + whatever!("at least one origin is required"); + } let first_origin = LowerName::from(&origins[0]); Ok(Self { static_authority, @@ -181,19 +184,19 @@ fn parse_name_as_pkarr_with_origin( continue; } if name.num_labels() < origin.num_labels() + 1 { - bail!("not a valid pkarr name: missing pubkey"); + whatever!("not a valid pkarr name: missing pubkey"); } trace!("parse {origin}"); let labels = name.iter().rev(); let mut labels_without_origin = labels.skip(origin.num_labels() as usize); let pkey_label = 
labels_without_origin.next().expect("length checked above"); - let pkey_str = std::str::from_utf8(pkey_label)?; + let pkey_str = std::str::from_utf8(pkey_label).e()?; let pkey = PublicKeyBytes::from_z32(pkey_str).context("not a valid pkarr name: invalid pubkey")?; - let remaining_name = Name::from_labels(labels_without_origin.rev())?; + let remaining_name = Name::from_labels(labels_without_origin.rev()).e()?; return Ok((remaining_name, pkey, origin.clone())); } - bail!("name does not match any allowed origin"); + whatever!("name does not match any allowed origin"); } fn err_refused(e: impl fmt::Debug) -> LookupError { diff --git a/iroh-dns-server/src/http.rs b/iroh-dns-server/src/http.rs index 00df72eac84..32eee2ce47d 100644 --- a/iroh-dns-server/src/http.rs +++ b/iroh-dns-server/src/http.rs @@ -5,7 +5,6 @@ use std::{ time::Instant, }; -use anyhow::{bail, Context, Result}; use axum::{ extract::{ConnectInfo, Request, State}, handler::Handler, @@ -15,7 +14,9 @@ use axum::{ routing::get, Router, }; +use n0_snafu::{Result, ResultExt}; use serde::{Deserialize, Serialize}; +use snafu::whatever; use tokio::{net::TcpListener, task::JoinSet}; use tower_http::{ cors::{self, CorsLayer}, @@ -74,7 +75,7 @@ impl HttpServer { state: AppState, ) -> Result { if http_config.is_none() && https_config.is_none() { - bail!("Either http or https config is required"); + whatever!("Either http or https config is required"); } let app = create_app(state, &rate_limit_config); @@ -88,8 +89,8 @@ impl HttpServer { config.port, ); let app = app.clone(); - let listener = TcpListener::bind(bind_addr).await?.into_std()?; - let bound_addr = listener.local_addr()?; + let listener = TcpListener::bind(bind_addr).await.e()?.into_std().e()?; + let bound_addr = listener.local_addr().e()?; let fut = axum_server::from_tcp(listener) .serve(app.into_make_service_with_connect_info::()); info!("HTTP server listening on {bind_addr}"); @@ -124,8 +125,8 @@ impl HttpServer { ) .await? 
}; - let listener = TcpListener::bind(bind_addr).await?.into_std()?; - let bound_addr = listener.local_addr()?; + let listener = TcpListener::bind(bind_addr).await.e()?.into_std().e()?; + let bound_addr = listener.local_addr().e()?; let fut = axum_server::from_tcp(listener) .acceptor(acceptor) .serve(app.into_make_service_with_connect_info::()); @@ -165,18 +166,18 @@ impl HttpServer { /// /// Runs forever unless tasks fail. pub async fn run_until_done(mut self) -> Result<()> { - let mut final_res: anyhow::Result<()> = Ok(()); + let mut final_res: Result<()> = Ok(()); while let Some(res) = self.tasks.join_next().await { match res { Ok(Ok(())) => {} Err(err) if err.is_cancelled() => {} Ok(Err(err)) => { warn!(?err, "task failed"); - final_res = Err(anyhow::Error::from(err)); + final_res = Err(err).context("task"); } Err(err) => { warn!(?err, "task panicked"); - final_res = Err(err.into()); + final_res = Err(err).context("join"); } } } diff --git a/iroh-dns-server/src/http/doh.rs b/iroh-dns-server/src/http/doh.rs index 43cf5f53f39..d093c5df1e4 100644 --- a/iroh-dns-server/src/http/doh.rs +++ b/iroh-dns-server/src/http/doh.rs @@ -3,7 +3,6 @@ // This module is mostly copied from // https://github.com/fission-codes/fission-server/blob/main/fission-server/src/routes/doh.rs -use anyhow::anyhow; use axum::{ extract::State, response::{IntoResponse, Response}, @@ -17,6 +16,7 @@ use http::{ header::{CACHE_CONTROL, CONTENT_TYPE}, HeaderValue, StatusCode, }; +use n0_snafu::ResultExt; use super::error::AppResult; use crate::state::AppState; @@ -32,7 +32,7 @@ pub async fn get( DnsRequestQuery(request, accept_type): DnsRequestQuery, ) -> AppResult { let message_bytes = state.dns_handler.answer_request(request).await?; - let message = proto::op::Message::from_bytes(&message_bytes).map_err(|e| anyhow!(e))?; + let message = proto::op::Message::from_bytes(&message_bytes).e()?; let min_ttl = message.answers().iter().map(|rec| rec.ttl()).min(); @@ -49,8 +49,7 @@ pub async fn get( 
.insert(CONTENT_TYPE, accept_type.to_header_value()); if let Some(min_ttl) = min_ttl { - let maxage = - HeaderValue::from_str(&format!("s-maxage={min_ttl}")).map_err(|e| anyhow!(e))?; + let maxage = HeaderValue::from_str(&format!("s-maxage={min_ttl}")).e()?; response.headers_mut().insert(CACHE_CONTROL, maxage); } diff --git a/iroh-dns-server/src/http/doh/response.rs b/iroh-dns-server/src/http/doh/response.rs index 12399fc122a..2f706646418 100644 --- a/iroh-dns-server/src/http/doh/response.rs +++ b/iroh-dns-server/src/http/doh/response.rs @@ -3,9 +3,10 @@ // This module is mostly copied from // https://github.com/fission-codes/fission-server/blob/394de877fad021260c69fdb1edd7bb4b2f98108c/fission-core/src/dns.rs -use anyhow::{ensure, Result}; use hickory_server::proto; +use n0_snafu::Result; use serde::{Deserialize, Serialize}; +use snafu::ensure_whatever; #[derive(Debug, Serialize, Deserialize)] /// JSON representation of a DNS response @@ -47,17 +48,17 @@ pub struct DnsResponse { impl DnsResponse { /// Create a new JSON response from a DNS message pub fn from_message(message: proto::op::Message) -> Result { - ensure!( + ensure_whatever!( message.message_type() == proto::op::MessageType::Response, "Expected message type to be response" ); - ensure!( + ensure_whatever!( message.query_count() == message.queries().len() as u16, "Query count mismatch" ); - ensure!( + ensure_whatever!( message.answer_count() == message.answers().len() as u16, "Answer count mismatch" ); diff --git a/iroh-dns-server/src/http/error.rs b/iroh-dns-server/src/http/error.rs index 7f8ab542d85..cd2c3acce24 100644 --- a/iroh-dns-server/src/http/error.rs +++ b/iroh-dns-server/src/http/error.rs @@ -49,8 +49,8 @@ impl IntoResponse for AppError { } } -impl From for AppError { - fn from(value: anyhow::Error) -> Self { +impl From for AppError { + fn from(value: n0_snafu::Error) -> Self { Self { status: StatusCode::INTERNAL_SERVER_ERROR, detail: Some(value.to_string()), diff --git 
a/iroh-dns-server/src/http/pkarr.rs b/iroh-dns-server/src/http/pkarr.rs index 061818e1073..2f02514c040 100644 --- a/iroh-dns-server/src/http/pkarr.rs +++ b/iroh-dns-server/src/http/pkarr.rs @@ -1,4 +1,3 @@ -use anyhow::Result; use axum::{ extract::{Path, State}, response::IntoResponse, diff --git a/iroh-dns-server/src/http/tls.rs b/iroh-dns-server/src/http/tls.rs index a151e9ad65d..0de5a55233c 100644 --- a/iroh-dns-server/src/http/tls.rs +++ b/iroh-dns-server/src/http/tls.rs @@ -5,13 +5,14 @@ use std::{ sync::{Arc, OnceLock}, }; -use anyhow::{bail, Context, Result}; use axum_server::{ accept::Accept, tls_rustls::{RustlsAcceptor, RustlsConfig}, }; use n0_future::{future::Boxed as BoxFuture, FutureExt}; +use n0_snafu::{Result, ResultExt}; use serde::{Deserialize, Serialize}; +use snafu::whatever; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls_acme::{axum::AxumAcceptor, caches::DirCache, AcmeConfig}; use tokio_stream::StreamExt; @@ -74,9 +75,11 @@ impl Acce impl TlsAcceptor { async fn self_signed(domains: Vec) -> Result { - let rcgen::CertifiedKey { cert, key_pair } = rcgen::generate_simple_self_signed(domains)?; - let config = - RustlsConfig::from_der(vec![cert.der().to_vec()], key_pair.serialize_der()).await?; + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(domains).e()?; + let config = RustlsConfig::from_der(vec![cert.der().to_vec()], key_pair.serialize_der()) + .await + .e()?; let acceptor = RustlsAcceptor::new(config); Ok(Self::Manual(acceptor)) } @@ -84,8 +87,9 @@ impl TlsAcceptor { async fn manual(domains: Vec, dir: PathBuf) -> Result { let config = rustls::ServerConfig::builder().with_no_client_auth(); if domains.len() != 1 { - bail!("Multiple domains in manual mode are not supported"); + whatever!("Multiple domains in manual mode are not supported"); } + let keyname = escape_hostname(&domains[0]); let cert_path = dir.join(format!("{keyname}.crt")); let key_path = dir.join(format!("{keyname}.key")); @@ -93,7 +97,7 
@@ impl TlsAcceptor { let certs = load_certs(cert_path).await?; let secret_key = load_secret_key(key_path).await?; - let config = config.with_single_cert(certs, secret_key)?; + let config = config.with_single_cert(certs, secret_key).e()?; let config = RustlsConfig::from_config(Arc::new(config)); let acceptor = RustlsAcceptor::new(config); Ok(Self::Manual(acceptor)) @@ -140,7 +144,7 @@ async fn load_certs( .context("cannot open certificate file")?; let mut reader = std::io::Cursor::new(certfile); let certs: Result, std::io::Error> = rustls_pemfile::certs(&mut reader).collect(); - let certs = certs?; + let certs = certs.e()?; Ok(certs) } @@ -169,7 +173,7 @@ async fn load_secret_key( } } - bail!( + whatever!( "no keys found in {} (encrypted keys not supported)", filename.as_ref().display() ); diff --git a/iroh-dns-server/src/lib.rs b/iroh-dns-server/src/lib.rs index 25708fabec6..a2a37b9367d 100644 --- a/iroh-dns-server/src/lib.rs +++ b/iroh-dns-server/src/lib.rs @@ -21,13 +21,12 @@ mod tests { time::Duration, }; - use anyhow::Result; use iroh::{ discovery::pkarr::PkarrRelayClient, dns::DnsResolver, node_info::NodeInfo, RelayUrl, SecretKey, }; + use n0_snafu::{Result, ResultExt}; use pkarr::{SignedPacket, Timestamp}; - use testresult::TestResult; use tracing_test::traced_test; use crate::{ @@ -42,7 +41,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn pkarr_publish_dns_resolve() -> Result<()> { + async fn pkarr_publish_dns_resolve() -> Result { let (server, nameserver, http_url) = Server::spawn_for_tests().await?; let pkarr_relay_url = { let mut url = http_url.clone(); @@ -55,96 +54,98 @@ mod tests { let mut packet = dns::Packet::new_reply(0); // record at root packet.answers.push(dns::ResourceRecord::new( - dns::Name::new("").unwrap(), + dns::Name::new("").e()?, dns::CLASS::IN, 30, - dns::rdata::RData::TXT("hi0".try_into()?), + dns::rdata::RData::TXT("hi0".try_into().unwrap()), )); // record at level one packet.answers.push(dns::ResourceRecord::new( - 
dns::Name::new("_hello").unwrap(), + dns::Name::new("_hello").e()?, dns::CLASS::IN, 30, - dns::rdata::RData::TXT("hi1".try_into()?), + dns::rdata::RData::TXT("hi1".try_into().unwrap()), )); // record at level two packet.answers.push(dns::ResourceRecord::new( - dns::Name::new("_hello.world").unwrap(), + dns::Name::new("_hello.world").e()?, dns::CLASS::IN, 30, - dns::rdata::RData::TXT("hi2".try_into()?), + dns::rdata::RData::TXT("hi2".try_into().unwrap()), )); // multiple records for same name packet.answers.push(dns::ResourceRecord::new( - dns::Name::new("multiple").unwrap(), + dns::Name::new("multiple").e()?, dns::CLASS::IN, 30, - dns::rdata::RData::TXT("hi3".try_into()?), + dns::rdata::RData::TXT("hi3".try_into().unwrap()), )); packet.answers.push(dns::ResourceRecord::new( - dns::Name::new("multiple").unwrap(), + dns::Name::new("multiple").e()?, dns::CLASS::IN, 30, - dns::rdata::RData::TXT("hi4".try_into()?), + dns::rdata::RData::TXT("hi4".try_into().unwrap()), )); // record of type A packet.answers.push(dns::ResourceRecord::new( - dns::Name::new("").unwrap(), + dns::Name::new("").e()?, dns::CLASS::IN, 30, dns::rdata::RData::A(Ipv4Addr::LOCALHOST.into()), )); // record of type AAAA packet.answers.push(dns::ResourceRecord::new( - dns::Name::new("foo.bar.baz").unwrap(), + dns::Name::new("foo.bar.baz").e()?, dns::CLASS::IN, 30, dns::rdata::RData::AAAA(Ipv6Addr::LOCALHOST.into()), )); - SignedPacket::new(&keypair, &packet.answers, Timestamp::now())? + SignedPacket::new(&keypair, &packet.answers, Timestamp::now()).e()? }; let pkarr_client = pkarr::Client::builder() .no_default_network() - .relays(&[pkarr_relay_url])? - .build()?; - pkarr_client.publish(&signed_packet, None).await?; + .relays(&[pkarr_relay_url]) + .e()? 
+ .build() + .e()?; + pkarr_client.publish(&signed_packet, None).await.e()?; use hickory_server::proto::rr::Name; let pubkey = signed_packet.public_key().to_z32(); let resolver = test_resolver(nameserver); // resolve root record - let name = Name::from_utf8(format!("{pubkey}."))?; + let name = Name::from_utf8(format!("{pubkey}.")).e()?; let res = resolver.lookup_txt(name, DNS_TIMEOUT).await?; let records = res.into_iter().map(|t| t.to_string()).collect::>(); assert_eq!(records, vec!["hi0".to_string()]); // resolve level one record - let name = Name::from_utf8(format!("_hello.{pubkey}."))?; + let name = Name::from_utf8(format!("_hello.{pubkey}.")).e()?; let res = resolver.lookup_txt(name, DNS_TIMEOUT).await?; let records = res.into_iter().map(|t| t.to_string()).collect::>(); assert_eq!(records, vec!["hi1".to_string()]); // resolve level two record - let name = Name::from_utf8(format!("_hello.world.{pubkey}."))?; + let name = Name::from_utf8(format!("_hello.world.{pubkey}.")).e()?; let res = resolver.lookup_txt(name, DNS_TIMEOUT).await?; let records = res.into_iter().map(|t| t.to_string()).collect::>(); assert_eq!(records, vec!["hi2".to_string()]); // resolve multiple records for same name - let name = Name::from_utf8(format!("multiple.{pubkey}."))?; + let name = Name::from_utf8(format!("multiple.{pubkey}.")).e()?; let res = resolver.lookup_txt(name, DNS_TIMEOUT).await?; let records = res.into_iter().map(|t| t.to_string()).collect::>(); assert_eq!(records, vec!["hi3".to_string(), "hi4".to_string()]); // resolve A record - let name = Name::from_utf8(format!("{pubkey}."))?; + let name = Name::from_utf8(format!("{pubkey}.")).e()?; let res = resolver.lookup_ipv4(name, DNS_TIMEOUT).await?; let records = res.collect::>(); assert_eq!(records, vec![Ipv4Addr::LOCALHOST]); // resolve AAAA record - let name = Name::from_utf8(format!("foo.bar.baz.{pubkey}."))?; + let name = Name::from_utf8(format!("foo.bar.baz.{pubkey}.")).e()?; let res = resolver.lookup_ipv6(name, 
DNS_TIMEOUT).await?; let records = res.collect::>(); assert_eq!(records, vec![Ipv6Addr::LOCALHOST]); @@ -155,7 +156,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn integration_smoke() -> Result<()> { + async fn integration_smoke() -> Result { let (server, nameserver, http_url) = Server::spawn_for_tests().await?; let pkarr_relay = { @@ -187,7 +188,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn store_eviction() -> TestResult<()> { + async fn store_eviction() -> Result { let options = ZoneStoreOptions { eviction: Duration::from_millis(100), eviction_interval: Duration::from_millis(100), @@ -217,9 +218,9 @@ mod tests { #[tokio::test] #[traced_test] - async fn integration_mainline() -> Result<()> { + async fn integration_mainline() -> Result { // run a mainline testnet - let testnet = pkarr::mainline::Testnet::new_async(5).await?; + let testnet = pkarr::mainline::Testnet::new_async(5).await.e()?; let bootstrap = testnet.bootstrap.clone(); // spawn our server with mainline support @@ -240,8 +241,9 @@ mod tests { let pkarr = pkarr::Client::builder() .no_default_network() .dht(|builder| builder.bootstrap(&testnet.bootstrap)) - .build()?; - pkarr.publish(&signed_packet, None).await?; + .build() + .e()?; + pkarr.publish(&signed_packet, None).await.e()?; // resolve via DNS from our server, which will lookup from our DHT let resolver = test_resolver(nameserver); @@ -263,6 +265,7 @@ mod tests { let node_id = secret_key.public(); let relay_url: RelayUrl = "https://relay.example.".parse()?; let node_info = NodeInfo::new(node_id).with_relay_url(Some(relay_url.clone())); - node_info.to_pkarr_signed_packet(&secret_key, 30) + let packet = node_info.to_pkarr_signed_packet(&secret_key, 30)?; + Ok(packet) } } diff --git a/iroh-dns-server/src/main.rs b/iroh-dns-server/src/main.rs index 05f18bc4d35..83dfc514eec 100644 --- a/iroh-dns-server/src/main.rs +++ b/iroh-dns-server/src/main.rs @@ -1,8 +1,8 @@ use std::path::PathBuf; -use anyhow::Result; use clap::Parser; use 
iroh_dns_server::{config::Config, server::run_with_config_until_ctrl_c}; +use n0_snafu::Result; use tracing::debug; #[derive(Parser, Debug)] diff --git a/iroh-dns-server/src/server.rs b/iroh-dns-server/src/server.rs index 8ee9ceb9e93..dcd7918a95f 100644 --- a/iroh-dns-server/src/server.rs +++ b/iroh-dns-server/src/server.rs @@ -1,9 +1,8 @@ //! The main server which combines the DNS and HTTP(S) servers. - use std::sync::Arc; -use anyhow::Result; use iroh_metrics::service::start_metrics_server; +use n0_snafu::{Result, ResultExt}; use tracing::info; use crate::{ @@ -29,7 +28,7 @@ pub async fn run_with_config_until_ctrl_c(config: Config) -> Result<()> { store = store.with_mainline_fallback(bootstrap); }; let server = Server::spawn(config, store, metrics).await?; - tokio::signal::ctrl_c().await?; + tokio::signal::ctrl_c().await.e()?; info!("shutdown"); server.shutdown().await?; Ok(()) @@ -39,7 +38,7 @@ pub async fn run_with_config_until_ctrl_c(config: Config) -> Result<()> { pub struct Server { http_server: HttpServer, dns_server: DnsServer, - metrics_task: tokio::task::JoinHandle>, + metrics_task: tokio::task::JoinHandle>, } impl Server { @@ -63,7 +62,7 @@ impl Server { if let Some(addr) = metrics_addr { let mut registry = iroh_metrics::Registry::default(); registry.register(metrics); - start_metrics_server(addr, Arc::new(registry)).await?; + start_metrics_server(addr, Arc::new(registry)).await.e()?; } Ok(()) }); @@ -141,7 +140,7 @@ impl Server { let server = Self::spawn(config, store, Default::default()).await?; let dns_addr = server.dns_server.local_addr(); let http_addr = server.http_server.http_addr().expect("http is set"); - let http_url = format!("http://{http_addr}").parse()?; + let http_url = format!("http://{http_addr}").parse::().e()?; Ok((server, dns_addr, http_url)) } } diff --git a/iroh-dns-server/src/store.rs b/iroh-dns-server/src/store.rs index 3831479cbfc..01e13b3962f 100644 --- a/iroh-dns-server/src/store.rs +++ b/iroh-dns-server/src/store.rs @@ -2,9 
+2,9 @@ use std::{collections::BTreeMap, num::NonZeroUsize, path::Path, sync::Arc, time::Duration}; -use anyhow::Result; use hickory_server::proto::rr::{Name, RecordSet, RecordType, RrKey}; use lru::LruCache; +use n0_snafu::Result; use pkarr::{Client as PkarrClient, SignedPacket}; use tokio::sync::Mutex; use tracing::{debug, trace, warn}; diff --git a/iroh-dns-server/src/store/signed_packets.rs b/iroh-dns-server/src/store/signed_packets.rs index c775f32ef1a..7047c1394c9 100644 --- a/iroh-dns-server/src/store/signed_packets.rs +++ b/iroh-dns-server/src/store/signed_packets.rs @@ -6,7 +6,7 @@ use std::{ time::{Duration, SystemTime}, }; -use anyhow::{Context, Result}; +use n0_snafu::{format_err, Result, ResultExt}; use pkarr::{SignedPacket, Timestamp}; use redb::{ backends::InMemoryBackend, Database, MultimapTableDefinition, ReadableTable, TableDefinition, @@ -109,7 +109,7 @@ impl Actor { } } - async fn run0(&mut self) -> anyhow::Result<()> { + async fn run0(&mut self) -> Result<()> { let expiry_us = self.options.eviction.as_micros() as u64; while let Some(msg) = self.recv.recv().await { // if we get a snapshot message here we don't need to do a write transaction @@ -122,15 +122,15 @@ impl Actor { }; trace!("batch"); self.recv.push_back(msg).unwrap(); - let transaction = self.db.begin_write()?; - let mut tables = Tables::new(&transaction)?; + let transaction = self.db.begin_write().e()?; + let mut tables = Tables::new(&transaction).e()?; let timeout = tokio::time::sleep(self.options.max_batch_time); tokio::pin!(timeout); for _ in 0..self.options.max_batch_size { tokio::select! 
{ _ = self.cancel.cancelled() => { drop(tables); - transaction.commit()?; + transaction.commit().e()?; return Ok(()); } _ = &mut timeout => break, @@ -157,15 +157,17 @@ impl Actor { continue; } else { // remove the old packet from the update time index - tables.update_time.remove(&existing.timestamp().to_bytes(), key.as_bytes())?; + tables.update_time.remove(&existing.timestamp().to_bytes(), key.as_bytes()).e()?; true } } else { false }; let value = packet.serialize(); - tables.signed_packets.insert(key.as_bytes(), &value[..])?; - tables.update_time.insert(&packet.timestamp().to_bytes(), key.as_bytes())?; + tables.signed_packets + .insert(key.as_bytes(), &value[..]).e()?; + tables.update_time + .insert(&packet.timestamp().to_bytes(), key.as_bytes()).e()?; if replaced { self.metrics.store_packets_updated.inc(); } else { @@ -175,9 +177,9 @@ impl Actor { } Message::Remove { key, res } => { trace!("remove {}", key); - let updated = if let Some(row) = tables.signed_packets.remove(key.as_bytes())? { - let packet = SignedPacket::deserialize(row.value())?; - tables.update_time.remove(&packet.timestamp().to_bytes(), key.as_bytes())?; + let updated = if let Some(row) = tables.signed_packets.remove(key.as_bytes()).e()? { + let packet = SignedPacket::deserialize(row.value()).e()?; + tables.update_time.remove(&packet.timestamp().to_bytes(), key.as_bytes()).e()?; self.metrics.store_packets_removed.inc(); true } else { @@ -194,17 +196,17 @@ impl Actor { if let Some(packet) = get_packet(&tables.signed_packets, &key)? 
{ let expired = Timestamp::now() - expiry_us; if packet.timestamp() < expired { - tables.update_time.remove(&time.to_bytes(), key.as_bytes())?; - let _ = tables.signed_packets.remove(key.as_bytes())?; + tables.update_time.remove(&time.to_bytes(), key.as_bytes()).e()?; + let _ = tables.signed_packets.remove(key.as_bytes()).e()?; self.metrics.store_packets_expired.inc(); debug!("removed expired packet {key}"); } else { debug!("packet {key} is no longer expired, removing obsolete expiry entry"); - tables.update_time.remove(&time.to_bytes(), key.as_bytes())?; + tables.update_time.remove(&time.to_bytes(), key.as_bytes()).e()?; } } else { debug!("expired packet {key} not found, remove from expiry table"); - tables.update_time.remove(&time.to_bytes(), key.as_bytes())?; + tables.update_time.remove(&time.to_bytes(), key.as_bytes()).e()?; } } } @@ -212,7 +214,7 @@ impl Actor { } } drop(tables); - transaction.commit()?; + transaction.commit().e()?; } Ok(()) } @@ -246,10 +248,10 @@ pub(super) struct Snapshot { impl Snapshot { pub fn new(db: &Database) -> Result { - let tx = db.begin_read()?; + let tx = db.begin_read().e()?; Ok(Self { - signed_packets: tx.open_table(SIGNED_PACKETS_TABLE)?, - update_time: tx.open_multimap_table(UPDATE_TIME_TABLE)?, + signed_packets: tx.open_table(SIGNED_PACKETS_TABLE).e()?, + update_time: tx.open_multimap_table(UPDATE_TIME_TABLE).e()?, }) } } @@ -278,15 +280,17 @@ impl SignedPacketStore { pub fn in_memory(options: Options, metrics: Arc) -> Result { info!("using in-memory packet database"); - let db = Database::builder().create_with_backend(InMemoryBackend::new())?; + let db = Database::builder() + .create_with_backend(InMemoryBackend::new()) + .e()?; Self::open(db, options, metrics) } pub fn open(db: Database, options: Options, metrics: Arc) -> Result { // create tables - let write_tx = db.begin_write()?; - let _ = Tables::new(&write_tx)?; - write_tx.commit()?; + let write_tx = db.begin_write().e()?; + let _ = Tables::new(&write_tx).e()?; + 
write_tx.commit().e()?; let (send, recv) = mpsc::channel(1024); let send2 = send.clone(); let cancel = CancellationToken::new(); @@ -315,22 +319,29 @@ impl SignedPacketStore { pub async fn upsert(&self, packet: SignedPacket) -> Result { let (tx, rx) = oneshot::channel(); - self.send.send(Message::Upsert { packet, res: tx }).await?; - Ok(rx.await?) + self.send + .send(Message::Upsert { packet, res: tx }) + .await + .e()?; + rx.await.e() } pub async fn get(&self, key: &PublicKeyBytes) -> Result> { let (tx, rx) = oneshot::channel(); - self.send.send(Message::Get { key: *key, res: tx }).await?; - Ok(rx.await?) + self.send + .send(Message::Get { key: *key, res: tx }) + .await + .e()?; + rx.await.e() } pub async fn remove(&self, key: &PublicKeyBytes) -> Result { let (tx, rx) = oneshot::channel(); self.send .send(Message::Remove { key: *key, res: tx }) - .await?; - Ok(rx.await?) + .await + .e()?; + rx.await.e() } } @@ -354,7 +365,7 @@ fn get_packet( buf.extend(data); match SignedPacket::deserialize(&buf) { Ok(packet) => Ok(Some(packet)), - Err(err2) => Err(anyhow::anyhow!("Failed to decode as pkarr v3: {err:#}. Also failed to decode as pkarr v2: {err2:#}")) + Err(err2) => Err(format_err!("Failed to decode as pkarr v3: {err:#}. Also failed to decode as pkarr v2: {err2:#}")) } } } @@ -375,20 +386,19 @@ async fn evict_task(send: mpsc::Sender, options: Options, cancel: Cance } /// Periodically check for expired packets and remove them. 
-async fn evict_task_inner(send: mpsc::Sender, options: Options) -> anyhow::Result<()> { +async fn evict_task_inner(send: mpsc::Sender, options: Options) -> Result<()> { let expiry_us = options.eviction.as_micros() as u64; loop { let (tx, rx) = oneshot::channel(); let _ = send.send(Message::Snapshot { res: tx }).await.ok(); // if we can't get the snapshot we exit the loop, main actor dead - let Ok(snapshot) = rx.await else { - anyhow::bail!("failed to get snapshot"); - }; + let snapshot = rx.await.context("failed to get snapshot")?; + let expired = Timestamp::now() - expiry_us; trace!("evicting packets older than {}", fmt_time(expired)); // if getting the range fails we exit the loop and shut down // if individual reads fail we log the error and limp on - for item in snapshot.update_time.range(..expired.to_bytes())? { + for item in snapshot.update_time.range(..expired.to_bytes()).e()? { let (time, keys) = match item { Ok(v) => v, Err(e) => { @@ -411,8 +421,9 @@ async fn evict_task_inner(send: mpsc::Sender, options: Options) -> anyh } }; let key = PublicKeyBytes::new(key.value()); + debug!("evicting expired packet {} {}", fmt_time(time), key); - send.send(Message::CheckExpired { time, key }).await?; + send.send(Message::CheckExpired { time, key }).await.e()?; } } // sleep for the eviction interval so we don't constantly check @@ -440,7 +451,7 @@ impl IoThread { F: FnOnce() -> Fut + Send + 'static, Fut: Future, { - let rt = tokio::runtime::Handle::try_current()?; + let rt = tokio::runtime::Handle::try_current().context("get tokio handle")?; let handle = std::thread::Builder::new() .name(name.into()) .spawn(move || rt.block_on(f())) diff --git a/iroh-dns-server/src/util.rs b/iroh-dns-server/src/util.rs index e8eb7fa906d..4625f269ac2 100644 --- a/iroh-dns-server/src/util.rs +++ b/iroh-dns-server/src/util.rs @@ -5,7 +5,6 @@ use std::{ sync::Arc, }; -use anyhow::{anyhow, Result}; use hickory_server::proto::{ op::Message, rr::{ @@ -14,6 +13,7 @@ use 
hickory_server::proto::{ }, serialize::binary::BinDecodable, }; +use n0_snafu::{Error, Result, ResultExt}; use pkarr::SignedPacket; #[derive( @@ -27,8 +27,8 @@ impl PublicKeyBytes { } pub fn from_z32(s: &str) -> Result { - let bytes = z32::decode(s.as_bytes())?; - let bytes: [u8; 32] = bytes.try_into().map_err(|_| anyhow!("invalid length"))?; + let bytes = z32::decode(s.as_bytes()).e()?; + let bytes = TryInto::<[u8; 32]>::try_into(&bytes[..]).context("invalid length")?; Ok(Self(bytes)) } @@ -68,14 +68,14 @@ impl From for PublicKeyBytes { } impl TryFrom for pkarr::PublicKey { - type Error = anyhow::Error; + type Error = Error; fn try_from(value: PublicKeyBytes) -> Result { - pkarr::PublicKey::try_from(&value.0).map_err(anyhow::Error::from) + pkarr::PublicKey::try_from(&value.0).e() } } impl FromStr for PublicKeyBytes { - type Err = anyhow::Error; + type Err = Error; fn from_str(s: &str) -> Result { Self::from_z32(s) } @@ -89,7 +89,7 @@ impl AsRef<[u8; 32]> for PublicKeyBytes { pub fn signed_packet_to_hickory_message(signed_packet: &SignedPacket) -> Result { let encoded = signed_packet.encoded_packet(); - let message = Message::from_bytes(&encoded)?; + let message = Message::from_bytes(&encoded).e()?; Ok(message) } @@ -97,7 +97,7 @@ pub fn signed_packet_to_hickory_records_without_origin( signed_packet: &SignedPacket, filter: impl Fn(&Record) -> bool, ) -> Result<(Label, BTreeMap>)> { - let common_zone = Label::from_utf8(&signed_packet.public_key().to_z32())?; + let common_zone = Label::from_utf8(&signed_packet.public_key().to_z32()).e()?; let mut message = signed_packet_to_hickory_message(signed_packet)?; let answers = message.take_answers(); let mut output: BTreeMap> = BTreeMap::new(); @@ -111,7 +111,7 @@ pub fn signed_packet_to_hickory_records_without_origin( if name.num_labels() < 1 { continue; } - let zone = name.iter().next_back().unwrap().into_label()?; + let zone = name.iter().next_back().unwrap().into_label().e()?; if zone != common_zone { continue; } @@ 
-120,7 +120,7 @@ pub fn signed_packet_to_hickory_records_without_origin( } let name_without_zone = - Name::from_labels(name.iter().take(name.num_labels() as usize - 1))?; + Name::from_labels(name.iter().take(name.num_labels() as usize - 1)).e()?; record.set_name(name_without_zone); let rrkey = RrKey::new(record.name().into(), record.record_type()); @@ -145,7 +145,7 @@ pub fn record_set_append_origin( origin: &Name, serial: u32, ) -> Result { - let new_name = input.name().clone().append_name(origin)?; + let new_name = input.name().clone().append_name(origin).e()?; let mut output = RecordSet::new(new_name.clone(), input.record_type(), serial); // TODO: less clones for record in input.records_without_rrsigs() { diff --git a/iroh-relay/Cargo.toml b/iroh-relay/Cargo.toml index 0e56551fb9a..a3681693607 100644 --- a/iroh-relay/Cargo.toml +++ b/iroh-relay/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0" authors = ["n0 team"] repository = "https://github.com/n0-computer/iroh" keywords = ["networking", "holepunching", "p2p"] -rust-version = "1.81" +rust-version = "1.85" [lib] # We need "cdylib" to actually generate .wasm files when we run with --target=wasm32-unknown-unknown. 
@@ -19,7 +19,6 @@ crate-type = ["lib", "cdylib"] workspace = true [dependencies] -anyhow = { version = "1" } bytes = "1.7" derive_more = { version = "1.0.0", features = [ "debug", @@ -53,7 +52,6 @@ rustls = { version = "0.23", default-features = false, features = ["ring"] } serde = { version = "1", features = ["derive", "rc"] } strum = { version = "0.26", features = ["derive"] } stun-rs = "0.1.11" -thiserror = "2" tokio = { version = "1", features = [ "io-util", "macros", @@ -74,6 +72,9 @@ webpki_types = { package = "rustls-pki-types", version = "1.12" } data-encoding = "2.6.0" lru = "0.13" z32 = "1.0.3" +snafu = { version = "0.8.5", features = ["rust_1_81"] } +n0-snafu = "0.2.0" +nested_enum_utils = "0.2.0" # server feature clap = { version = "4", features = ["derive"], optional = true } @@ -122,7 +123,6 @@ clap = { version = "4", features = ["derive"] } crypto_box = { version = "0.9.1", features = ["serde", "chacha20"] } proptest = "1.2.0" rand_chacha = "0.3.1" -testresult = "0.4.0" tokio = { version = "1", features = [ "io-util", "sync", diff --git a/iroh-relay/proptest-regressions/protos/relay.txt b/iroh-relay/proptest-regressions/protos/relay.txt new file mode 100644 index 00000000000..e718aef9a1e --- /dev/null +++ b/iroh-relay/proptest-regressions/protos/relay.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. 
+cc 9295f5287162dfb180e5826e563c2cea08b477b803ef412ff8351eb5c3eb45ef # shrinks to frame = KeepAlive diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index bf0d4680652..ba73d75bd3d 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -9,23 +9,25 @@ use std::{ task::{self, Poll}, }; -use anyhow::{anyhow, bail, Result}; use conn::Conn; use iroh_base::{RelayUrl, SecretKey}; use n0_future::{ split::{split, SplitSink, SplitStream}, - Sink, Stream, + time, Sink, Stream, }; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, Snafu}; #[cfg(any(test, feature = "test-utils"))] use tracing::warn; use tracing::{debug, event, trace, Level}; use url::Url; -pub use self::conn::{ConnSendError, ReceivedMessage, SendMessage}; +pub use self::conn::{ReceivedMessage, RecvError, SendError, SendMessage}; #[cfg(not(wasm_browser))] -use crate::dns::DnsResolver; +use crate::dns::{DnsError, DnsResolver}; use crate::{ http::{Protocol, RELAY_PATH}, + protos::relay::SendError as SendRelayError, KeyCache, }; @@ -37,6 +39,82 @@ mod tls; #[cfg(not(wasm_browser))] mod util; +/// Connection errors. +/// +/// `ConnectError` contains `DialError`, errors that can occur while dialing the +/// relay, as well as errors that occur while creating or maintaining a connection. 
+#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ConnectError { + #[snafu(display("Invalid URL for websocket: {url}"))] + InvalidWebsocketUrl { url: Url }, + #[snafu(display("Invalid relay URL: {url}"))] + InvalidRelayUrl { url: Url }, + #[snafu(transparent)] + Websocket { + #[cfg(not(wasm_browser))] + source: tokio_websockets::Error, + #[cfg(wasm_browser)] + source: ws_stream_wasm::WsErr, + }, + #[snafu(transparent)] + Handshake { source: SendRelayError }, + #[snafu(transparent)] + Dial { source: DialError }, + #[snafu(display("Unexpected status during upgrade: {code}"))] + UnexpectedUpgradeStatus { code: hyper::StatusCode }, + #[snafu(display("Failed to upgrade response"))] + Upgrade { source: hyper::Error }, + #[snafu(display("Invalid TLS servername"))] + InvalidTlsServername {}, + #[snafu(display("No local address available"))] + NoLocalAddr {}, + #[snafu(display("tls connection failed"))] + Tls { source: std::io::Error }, + #[cfg(wasm_browser)] + #[snafu(display("The relay protocol is not available in browsers"))] + RelayProtoNotAvailable {}, +} + +/// Errors that can occur while dialing the relay server. 
+#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum DialError { + #[snafu(display("Invliad target port"))] + InvalidTargetPort {}, + #[snafu(transparent)] + #[cfg(not(wasm_browser))] + Dns { source: DnsError }, + #[snafu(transparent)] + Timeout { source: time::Elapsed }, + #[snafu(transparent)] + Io { source: std::io::Error }, + #[snafu(display("Invalid URL: {url}"))] + InvalidUrl { url: Url }, + #[snafu(display("Failed proxy connection: {status}"))] + ProxyConnectInvalidStatus { status: hyper::StatusCode }, + #[snafu(display("Invalid Proxy URL {proxy_url}"))] + ProxyInvalidUrl { proxy_url: Url }, + #[snafu(display("failed to establish proxy connection"))] + ProxyConnect { source: hyper::Error }, + #[snafu(display("Invalid proxy TLS servername: {proxy_hostname}"))] + ProxyInvalidTlsServername { proxy_hostname: String }, + #[snafu(display("Invalid proxy target port"))] + ProxyInvalidTargetPort {}, +} + /// Build a Client. #[derive(derive_more::Debug, Clone)] pub struct ClientBuilder { @@ -138,7 +216,7 @@ impl ClientBuilder { } /// Establishes a new connection to the relay server. - pub async fn connect(&self) -> Result { + pub async fn connect(&self) -> Result { let (conn, local_addr) = match self.protocol { #[cfg(wasm_browser)] Protocol::Websocket => { @@ -157,9 +235,7 @@ impl ClientBuilder { (conn, Some(local_addr)) } #[cfg(wasm_browser)] - Protocol::Relay => { - bail!("Can only connect to relay using websockets in browsers."); - } + Protocol::Relay => return Err(RelayProtoNotAvailableSnafu.build()), }; event!( @@ -174,7 +250,7 @@ impl ClientBuilder { } #[cfg(wasm_browser)] - async fn connect_ws(&self) -> Result { + async fn connect_ws(&self) -> Result { let mut dial_url = (*self.url).clone(); dial_url.set_path(RELAY_PATH); // The relay URL is exchanged with the http(s) scheme in tickets and similar. 
@@ -185,7 +261,12 @@ impl ClientBuilder { "ws" => "ws", _ => "wss", }) - .map_err(|()| anyhow!("Invalid URL"))?; + .map_err(|_| { + InvalidWebsocketUrlSnafu { + url: dial_url.clone(), + } + .build() + })?; debug!(%dial_url, "Dialing relay by websocket"); @@ -218,7 +299,7 @@ impl Client { } impl Stream for Client { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.conn).poll_next(cx) @@ -226,7 +307,7 @@ impl Stream for Client { } impl Sink for Client { - type Error = ConnSendError; + type Error = SendError; fn poll_ready( mut self: Pin<&mut Self>, @@ -261,7 +342,7 @@ pub struct ClientSink { } impl Sink for ClientSink { - type Error = ConnSendError; + type Error = SendError; fn poll_ready( mut self: Pin<&mut Self>, @@ -304,7 +385,7 @@ impl ClientStream { } impl Stream for ClientStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_next(cx) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index e224e3374bd..2cc59b00343 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -5,19 +5,24 @@ use std::{ io, pin::Pin, + str::Utf8Error, task::{ready, Context, Poll}, }; -use anyhow::{bail, Result}; use bytes::Bytes; use iroh_base::{NodeId, SecretKey}; use n0_future::{time::Duration, Sink, Stream}; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, ResultExt, Snafu}; #[cfg(not(wasm_browser))] use tokio_util::codec::Framed; use tracing::debug; use super::KeyCache; -use crate::protos::relay::{ClientInfo, Frame, MAX_PACKET_SIZE, PROTOCOL_VERSION}; +use crate::protos::relay::{ + ClientInfo, Frame, RecvError as RecvRelayError, SendError as SendRelayError, MAX_PACKET_SIZE, + PROTOCOL_VERSION, +}; #[cfg(not(wasm_browser))] use crate::{ client::streams::{MaybeTlsStream, MaybeTlsStreamChained, ProxyStream}, @@ -25,42 +30,58 
@@ use crate::{ }; /// Error for sending messages to the relay server. -#[derive(Debug, thiserror::Error)] -pub enum ConnSendError { - /// An IO error. - #[error("IO error")] - Io(#[from] io::Error), - /// A protocol error. - #[error("Protocol error")] - Protocol(&'static str), -} - -#[cfg(wasm_browser)] -impl From for ConnSendError { - fn from(err: ws_stream_wasm::WsErr) -> Self { - use std::io::ErrorKind::*; - - use ws_stream_wasm::WsErr::*; - let kind = match err { - ConnectionNotOpen => NotConnected, - ReasonStringToLong | InvalidCloseCode { .. } | InvalidUrl { .. } => InvalidInput, - UnknownDataType | InvalidEncoding => InvalidData, - ConnectionFailed { .. } => ConnectionReset, - _ => Other, - }; - Self::Io(std::io::Error::new(kind, err.to_string())) - } +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum SendError { + #[cfg(not(wasm_browser))] + #[snafu(transparent)] + RelayIo { source: io::Error }, + #[snafu(transparent)] + WebsocketIo { + #[cfg(not(wasm_browser))] + source: tokio_websockets::Error, + #[cfg(wasm_browser)] + source: ws_stream_wasm::WsErr, + }, + #[snafu(display("Exceeds max packet size ({MAX_PACKET_SIZE}): {size}"))] + ExceedsMaxPacketSize { size: usize }, } -#[cfg(not(wasm_browser))] -impl From for ConnSendError { - fn from(err: tokio_websockets::Error) -> Self { - let io_err = match err { - tokio_websockets::Error::Io(io_err) => io_err, - _ => std::io::Error::other(err.to_string()), - }; - Self::Io(io_err) - } +/// Errors when receiving messages from the relay server. 
+#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum RecvError { + #[snafu(transparent)] + Io { source: io::Error }, + #[snafu(transparent)] + ProtocolSend { source: SendRelayError }, + #[snafu(transparent)] + ProtocolRecv { source: RecvRelayError }, + #[snafu(transparent)] + Websocket { + #[cfg(not(wasm_browser))] + source: tokio_websockets::Error, + #[cfg(wasm_browser)] + source: ws_stream_wasm::WsErr, + }, + #[snafu(display("invalid protocol message encoding"))] + InvalidProtocolMessageEncoding { source: Utf8Error }, + #[snafu(display("Unexpected frame received: {frame_type}"))] + UnexpectedFrame { + frame_type: crate::protos::relay::FrameType, + }, } /// A connection to a relay server. @@ -102,7 +123,7 @@ impl Conn { conn: ws_stream_wasm::WsStream, key_cache: KeyCache, secret_key: &SecretKey, - ) -> Result { + ) -> Result { let mut conn = Self::WsBrowser { conn, key_cache }; // exchange information with the server @@ -117,7 +138,7 @@ impl Conn { conn: MaybeTlsStreamChained, key_cache: KeyCache, secret_key: &SecretKey, - ) -> Result { + ) -> Result { let conn = Framed::new(conn, RelayCodec::new(key_cache)); let mut conn = Self::Relay { conn }; @@ -133,7 +154,7 @@ impl Conn { conn: tokio_websockets::WebSocketStream>, key_cache: KeyCache, secret_key: &SecretKey, - ) -> Result { + ) -> Result { let mut conn = Self::Ws { conn, key_cache }; // exchange information with the server @@ -144,7 +165,7 @@ impl Conn { } /// Sends the server handshake message. 
-async fn server_handshake(writer: &mut Conn, secret_key: &SecretKey) -> Result<()> { +async fn server_handshake(writer: &mut Conn, secret_key: &SecretKey) -> Result<(), SendRelayError> { debug!("server_handshake: started"); let client_info = ClientInfo { version: PROTOCOL_VERSION, @@ -157,7 +178,7 @@ async fn server_handshake(writer: &mut Conn, secret_key: &SecretKey) -> Result<( } impl Stream for Conn { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { @@ -167,7 +188,7 @@ impl Stream for Conn { let message = ReceivedMessage::try_from(frame); Poll::Ready(Some(message)) } - Some(Err(err)) => Poll::Ready(Some(Err(err))), + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), }, #[cfg(not(wasm_browser))] @@ -189,7 +210,8 @@ impl Stream for Conn { return Poll::Pending; } let frame = Frame::decode_from_ws_msg(msg.into_payload().into(), key_cache)?; - Poll::Ready(Some(ReceivedMessage::try_from(frame))) + let message = ReceivedMessage::try_from(frame); + Poll::Ready(Some(message)) } Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), None => Poll::Ready(None), @@ -214,7 +236,7 @@ impl Stream for Conn { } impl Sink for Conn { - type Error = ConnSendError; + type Error = SendError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { @@ -232,7 +254,7 @@ impl Sink for Conn { fn start_send(mut self: Pin<&mut Self>, frame: Frame) -> Result<(), Self::Error> { if let Frame::SendPacket { dst_key: _, packet } = &frame { if packet.len() > MAX_PACKET_SIZE { - return Err(ConnSendError::Protocol("Packet exceeds MAX_PACKET_SIZE")); + return Err(ExceedsMaxPacketSizeSnafu { size: packet.len() }.build()); } } match *self { @@ -279,7 +301,7 @@ impl Sink for Conn { } impl Sink for Conn { - type Error = ConnSendError; + type Error = SendError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { @@ -296,9 +318,8 @@ impl 
Sink for Conn { fn start_send(mut self: Pin<&mut Self>, item: SendMessage) -> Result<(), Self::Error> { if let SendMessage::SendPacket(_, bytes) = &item { - if bytes.len() > MAX_PACKET_SIZE { - return Err(ConnSendError::Protocol("Packet exceeds MAX_PACKET_SIZE")); - } + let size = bytes.len(); + snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); } let frame = Frame::from(item); match *self { @@ -394,7 +415,7 @@ pub enum ReceivedMessage { } impl TryFrom for ReceivedMessage { - type Error = anyhow::Error; + type Error = RecvError; fn try_from(frame: Frame) -> std::result::Result { match frame { @@ -414,7 +435,9 @@ impl TryFrom for ReceivedMessage { Frame::Ping { data } => Ok(ReceivedMessage::Ping(data)), Frame::Pong { data } => Ok(ReceivedMessage::Pong(data)), Frame::Health { problem } => { - let problem = std::str::from_utf8(&problem)?.to_owned(); + let problem = std::str::from_utf8(&problem) + .context(InvalidProtocolMessageEncodingSnafu)? + .to_owned(); let problem = Some(problem); Ok(ReceivedMessage::Health { problem }) } @@ -429,7 +452,10 @@ impl TryFrom for ReceivedMessage { try_for, }) } - _ => bail!("unexpected packet: {:?}", frame.typ()), + _ => Err(UnexpectedFrameSnafu { + frame_type: frame.typ(), + } + .build()), } } } diff --git a/iroh-relay/src/client/streams.rs b/iroh-relay/src/client/streams.rs index 12e570b10cc..3d5f28cd850 100644 --- a/iroh-relay/src/client/streams.rs +++ b/iroh-relay/src/client/streams.rs @@ -4,7 +4,6 @@ use std::{ task::{Context, Poll}, }; -use anyhow::{bail, Result}; use bytes::Bytes; use hyper::upgrade::{Parts, Upgraded}; use hyper_util::rt::TokioIo; @@ -90,24 +89,32 @@ impl AsyncWrite for MaybeTlsStreamChained { } } -pub fn downcast_upgrade(upgraded: Upgraded) -> Result { +pub fn downcast_upgrade(upgraded: Upgraded) -> Option { match upgraded.downcast::>>() { Ok(Parts { read_buf, io, .. 
}) => { // Prepend data to the reader to avoid data loss let read_buf = std::io::Cursor::new(read_buf); match io.into_inner() { MaybeTlsStream::Raw(conn) => { - Ok(MaybeTlsStreamChained::Raw(util::chain(read_buf, conn))) + Some(MaybeTlsStreamChained::Raw(util::chain(read_buf, conn))) } MaybeTlsStream::Tls(conn) => { - Ok(MaybeTlsStreamChained::Tls(util::chain(read_buf, conn))) + Some(MaybeTlsStreamChained::Tls(util::chain(read_buf, conn))) } } } - Err(_) => { - bail!( - "could not downcast the upgraded connection to a TcpStream or client::TlsStream" - ); + Err(upgraded) => { + if let Ok(Parts { read_buf, io, .. }) = + upgraded.downcast::>>() + { + let conn = io.into_inner(); + + // Prepend data to the reader to avoid data loss + let conn = util::chain(std::io::Cursor::new(read_buf), conn); + return Some(MaybeTlsStreamChained::Tls(conn)); + } + + None } } } diff --git a/iroh-relay/src/client/tls.rs b/iroh-relay/src/client/tls.rs index 83bdd973af3..fc3ba7eded0 100644 --- a/iroh-relay/src/client/tls.rs +++ b/iroh-relay/src/client/tls.rs @@ -14,7 +14,6 @@ // Based on tailscale/derp/derphttp/derphttp_client.go -use anyhow::Context; use bytes::Bytes; use data_encoding::BASE64URL; use http_body_util::Empty; @@ -26,6 +25,7 @@ use hyper::{ }; use n0_future::{task, time}; use rustls::client::Resumption; +use snafu::{OptionExt, ResultExt}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info_span, Instrument}; @@ -73,7 +73,7 @@ impl MaybeTlsStreamBuilder { self } - pub async fn connect(self) -> Result> { + pub async fn connect(self) -> Result, ConnectError> { let roots = rustls::RootCertStore { roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), }; @@ -98,7 +98,7 @@ impl MaybeTlsStreamBuilder { let local_addr = tcp_stream .local_addr() - .context("No local addr for TCP stream")?; + .map_err(|_| NoLocalAddrSnafu.build())?; debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected"); @@ -106,9 +106,13 @@ impl MaybeTlsStreamBuilder { 
debug!("Starting TLS handshake"); let hostname = self .tls_servername() - .ok_or_else(|| anyhow!("No tls servername"))?; + .ok_or_else(|| InvalidTlsServernameSnafu.build())?; + let hostname = hostname.to_owned(); - let tls_stream = tls_connector.connect(hostname, tcp_stream).await?; + let tls_stream = tls_connector + .connect(hostname, tcp_stream) + .await + .context(TlsSnafu)?; debug!("tls_connector connect success"); Ok(MaybeTlsStream::Tls(tls_stream)) } else { @@ -128,12 +132,15 @@ impl MaybeTlsStreamBuilder { } fn tls_servername(&self) -> Option { - self.url - .host_str() - .and_then(|s| rustls::pki_types::ServerName::try_from(s).ok()) + let host_str = self.url.host_str()?; + let servername = rustls::pki_types::ServerName::try_from(host_str).ok()?; + Some(servername) } - async fn dial_url(&self, tls_connector: &tokio_rustls::TlsConnector) -> Result { + async fn dial_url( + &self, + tls_connector: &tokio_rustls::TlsConnector, + ) -> Result { if let Some(ref proxy) = self.proxy_url { let stream = self.dial_url_proxy(proxy.clone(), tls_connector).await?; Ok(ProxyStream::Proxied(stream)) @@ -143,7 +150,7 @@ impl MaybeTlsStreamBuilder { } } - async fn dial_url_direct(&self) -> Result { + async fn dial_url_direct(&self) -> Result { use tokio::net::TcpStream; debug!(%self.url, "dial url"); let dst_ip = self @@ -151,7 +158,7 @@ impl MaybeTlsStreamBuilder { .resolve_host(&self.url, self.prefer_ipv6, DNS_TIMEOUT) .await?; - let port = url_port(&self.url).ok_or_else(|| anyhow!("Missing URL port"))?; + let port = url_port(&self.url).context(InvalidTargetPortSnafu)?; let addr = SocketAddr::new(dst_ip, port); debug!("connecting to {}", addr); @@ -159,9 +166,8 @@ impl MaybeTlsStreamBuilder { DIAL_NODE_TIMEOUT, async move { TcpStream::connect(addr).await }, ) - .await - .context("Timeout connecting")? 
- .context("Failed connecting")?; + .await??; + tcp_stream.set_nodelay(true)?; Ok(tcp_stream) @@ -171,7 +177,8 @@ impl MaybeTlsStreamBuilder { &self, proxy_url: Url, tls_connector: &tokio_rustls::TlsConnector, - ) -> Result, MaybeTlsStream>> { + ) -> Result, MaybeTlsStream>, DialError> + { use hyper_util::rt::TokioIo; use tokio::net::TcpStream; debug!(%self.url, %proxy_url, "dial url via proxy"); @@ -182,7 +189,7 @@ impl MaybeTlsStreamBuilder { .resolve_host(&proxy_url, self.prefer_ipv6, DNS_TIMEOUT) .await?; - let proxy_port = url_port(&proxy_url).ok_or_else(|| anyhow!("Missing proxy url port"))?; + let proxy_port = url_port(&proxy_url).context(ProxyInvalidTargetPortSnafu)?; let proxy_addr = SocketAddr::new(proxy_ip, proxy_port); debug!(%proxy_addr, "connecting to proxy"); @@ -190,9 +197,7 @@ impl MaybeTlsStreamBuilder { let tcp_stream = time::timeout(DIAL_NODE_TIMEOUT, async move { TcpStream::connect(proxy_addr).await }) - .await - .context("Timeout connecting")? - .context("Connecting")?; + .await??; tcp_stream.set_nodelay(true)?; @@ -200,19 +205,26 @@ impl MaybeTlsStreamBuilder { let io = if proxy_url.scheme() == "http" { MaybeTlsStream::Raw(tcp_stream) } else { - let hostname = proxy_url.host_str().context("No hostname in proxy URL")?; - let hostname = rustls::pki_types::ServerName::try_from(hostname.to_string())?; + let hostname = proxy_url.host_str().context(ProxyInvalidUrlSnafu { + proxy_url: proxy_url.clone(), + })?; + let hostname = + rustls::pki_types::ServerName::try_from(hostname.to_string()).map_err(|_| { + ProxyInvalidTlsServernameSnafu { + proxy_hostname: hostname.to_string(), + } + .build() + })?; let tls_stream = tls_connector.connect(hostname, tcp_stream).await?; MaybeTlsStream::Tls(tls_stream) }; let io = TokioIo::new(io); - let target_host = self - .url - .host_str() - .ok_or_else(|| anyhow!("Missing proxy host"))?; + let target_host = self.url.host_str().context(InvalidUrlSnafu { + url: self.url.clone(), + })?; - let port = 
url_port(&self.url).ok_or_else(|| anyhow!("invalid target port"))?; + let port = url_port(&self.url).context(InvalidTargetPortSnafu)?; // Establish Proxy Tunnel let mut req_builder = Request::builder() @@ -235,28 +247,33 @@ impl MaybeTlsStreamBuilder { let encoded = BASE64URL.encode(to_encode.as_bytes()); req_builder = req_builder.header("Proxy-Authorization", format!("Basic {}", encoded)); } - let req = req_builder.body(Empty::::new())?; + let req = req_builder + .body(Empty::::new()) + .expect("fixed config"); debug!("Sending proxy request: {:?}", req); - let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io) + .await + .context(ProxyConnectSnafu)?; task::spawn(async move { if let Err(err) = conn.with_upgrades().await { error!("Proxy connection failed: {:?}", err); } }); - let res = sender.send_request(req).await?; + let res = sender.send_request(req).await.context(ProxyConnectSnafu)?; if !res.status().is_success() { - bail!("Failed to connect to proxy: {}", res.status()); + return Err(ProxyConnectInvalidStatusSnafu { + status: res.status(), + } + .build()); } - let upgraded = hyper::upgrade::on(res).await?; - let Ok(Parts { io, read_buf, .. }) = - upgraded.downcast::>>() - else { - bail!("Invalid upgrade"); - }; + let upgraded = hyper::upgrade::on(res).await.context(ProxyConnectSnafu)?; + let Parts { io, read_buf, .. } = upgraded + .downcast::>>() + .expect("only this upgrade used"); let res = util::chain(std::io::Cursor::new(read_buf), io.into_inner()); @@ -269,7 +286,7 @@ impl ClientBuilder { /// set to [`HTTP_UPGRADE_PROTOCOL`]. 
/// /// [`HTTP_UPGRADE_PROTOCOL`]: crate::http::HTTP_UPGRADE_PROTOCOL - pub(super) async fn connect_relay(&self) -> Result<(Conn, SocketAddr)> { + pub(super) async fn connect_relay(&self) -> Result<(Conn, SocketAddr), ConnectError> { #[allow(unused_mut)] let mut builder = MaybeTlsStreamBuilder::new(self.url.clone().into(), self.dns_resolver.clone()) @@ -282,31 +299,31 @@ impl ClientBuilder { } let stream = builder.connect().await?; - let local_addr = stream.as_ref().local_addr()?; + let local_addr = stream + .as_ref() + .local_addr() + .map_err(|_| NoLocalAddrSnafu.build())?; let response = self.http_upgrade_relay(stream).await?; if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { - bail!( - "Unexpected status code: expected {}, actual: {}", - hyper::StatusCode::SWITCHING_PROTOCOLS, - response.status(), - ); + UnexpectedUpgradeStatusSnafu { + code: response.status(), + } + .fail()?; } debug!("starting upgrade"); - let upgraded = hyper::upgrade::on(response) - .await - .context("Upgrade failed")?; + let upgraded = hyper::upgrade::on(response).await.context(UpgradeSnafu)?; debug!("connection upgraded"); - let conn = downcast_upgrade(upgraded)?; + let conn = downcast_upgrade(upgraded).expect("must use TcpStream or client::TlsStream"); let conn = Conn::new_relay(conn, self.key_cache.clone(), &self.secret_key).await?; Ok((conn, local_addr)) } - pub(super) async fn connect_ws(&self) -> Result<(Conn, SocketAddr)> { + pub(super) async fn connect_ws(&self) -> Result<(Conn, SocketAddr), ConnectError> { let mut dial_url = (*self.url).clone(); dial_url.set_path(RELAY_PATH); // The relay URL is exchanged with the http(s) scheme in tickets and similar. 
@@ -317,7 +334,12 @@ impl ClientBuilder { "ws" => "ws", _ => "wss", }) - .map_err(|()| anyhow!("Invalid URL"))?; + .map_err(|_| { + InvalidWebsocketUrlSnafu { + url: dial_url.clone(), + } + .build() + })?; debug!(%dial_url, "Dialing relay by websocket"); @@ -332,18 +354,26 @@ impl ClientBuilder { } let stream = builder.connect().await?; - let local_addr = stream.as_ref().local_addr()?; + let local_addr = stream + .as_ref() + .local_addr() + .map_err(|_| NoLocalAddrSnafu.build())?; let (conn, response) = tokio_websockets::ClientBuilder::new() - .uri(dial_url.as_str())? + .uri(dial_url.as_str()) + .map_err(|_| { + InvalidRelayUrlSnafu { + url: dial_url.clone(), + } + .build() + })? .connect_on(stream) .await?; if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { - bail!( - "Unexpected status code: expected {}, actual: {}", - hyper::StatusCode::SWITCHING_PROTOCOLS, - response.status(), - ); + UnexpectedUpgradeStatusSnafu { + code: response.status(), + } + .fail()?; } let conn = Conn::new_ws(conn, self.key_cache.clone(), &self.secret_key).await?; @@ -352,17 +382,21 @@ impl ClientBuilder { } /// Sends the HTTP upgrade request to the relay server. - async fn http_upgrade_relay(&self, io: T) -> Result> + async fn http_upgrade_relay(&self, io: T) -> Result, ConnectError> where T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { use hyper_util::rt::TokioIo; - let host_header_value = host_header_value(self.url.clone())?; + let host_header_value = + host_header_value(self.url.clone()).context(InvalidRelayUrlSnafu { + url: Url::from(self.url.clone()), + })?; let io = TokioIo::new(io); let (mut request_sender, connection) = hyper::client::conn::http1::Builder::new() .handshake(io) - .await?; + .await + .context(UpgradeSnafu)?; task::spawn( // This task drives the HTTP exchange, completes once connection is upgraded. async move { @@ -382,8 +416,9 @@ impl ClientBuilder { // > A client MUST include a Host header field in all HTTP/1.1 request messages. 
// This header value helps reverse proxies identify how to forward requests. .header(HOST, host_header_value) - .body(http_body_util::Empty::::new())?; - request_sender.send_request(req).await.map_err(From::from) + .body(http_body_util::Empty::::new()) + .expect("fixed config"); + request_sender.send_request(req).await.context(UpgradeSnafu) } /// Reports whether IPv4 dials should be slightly @@ -399,9 +434,11 @@ impl ClientBuilder { } } -fn host_header_value(relay_url: RelayUrl) -> Result { +/// Returns none if no valid url host was found. +fn host_header_value(relay_url: RelayUrl) -> Option { // grab the host, turns e.g. https://example.com:8080/xyz -> example.com. - let relay_url_host = relay_url.host_str().context("Invalid URL")?; + let relay_url_host = relay_url.host_str()?; + // strip the trailing dot, if present: example.com. -> example.com let relay_url_host = relay_url_host.strip_suffix('.').unwrap_or(relay_url_host); // build the host header value (reserve up to 6 chars for the ":" and port digits): @@ -411,7 +448,7 @@ fn host_header_value(relay_url: RelayUrl) -> Result { host_header_value += ":"; host_header_value += &port.to_string(); } - Ok(host_header_value) + Some(host_header_value) } fn url_port(url: &Url) -> Option { @@ -430,14 +467,14 @@ fn url_port(url: &Url) -> Option { mod tests { use std::str::FromStr; - use anyhow::Result; + use n0_snafu::Result; use tracing_test::traced_test; use super::*; #[test] #[traced_test] - fn test_host_header_value() -> Result<()> { + fn test_host_header_value() -> Result { let cases = [ ( "https://euw1-1.relay.iroh.network.", @@ -448,7 +485,7 @@ mod tests { for (url, expected_host) in cases { let relay_url = RelayUrl::from_str(url)?; - let host = host_header_value(relay_url)?; + let host = host_header_value(relay_url).unwrap(); assert_eq!(host, expected_host); } diff --git a/iroh-relay/src/dns.rs b/iroh-relay/src/dns.rs index 1d179393b2d..39803744ae3 100644 --- a/iroh-relay/src/dns.rs +++ b/iroh-relay/src/dns.rs @@ 
-1,27 +1,90 @@ //! DNS resolver use std::{ - fmt::{self, Write}, + fmt, future::Future, net::{IpAddr, Ipv6Addr, SocketAddr}, }; -use anyhow::{bail, Context, Result}; use hickory_resolver::{name_server::TokioConnectionProvider, TokioResolver}; use iroh_base::NodeId; use n0_future::{ time::{self, Duration}, StreamExt, }; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, OptionExt, Snafu}; use url::Url; -use crate::node_info::NodeInfo; +use crate::node_info::{LookupError, NodeInfo}; /// The n0 testing DNS node origin, for production. pub const N0_DNS_NODE_ORIGIN_PROD: &str = "dns.iroh.link"; /// The n0 testing DNS node origin, for testing. pub const N0_DNS_NODE_ORIGIN_STAGING: &str = "staging-dns.iroh.link"; +/// Potential errors related to dns. +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +#[snafu(visibility(pub(crate)))] +pub enum DnsError { + #[snafu(transparent)] + Timeout { source: tokio::time::error::Elapsed }, + #[snafu(display("No response"))] + NoResponse {}, + #[snafu(display("Resolve failed ipv4: {ipv4}, ipv6 {ipv6}"))] + ResolveBoth { + ipv4: Box, + ipv6: Box, + }, + #[snafu(display("missing host"))] + MissingHost {}, + #[snafu(transparent)] + Resolve { + source: hickory_resolver::ResolveError, + }, + #[snafu(display("invalid DNS response: not a query for _iroh.z32encodedpubkey"))] + InvalidResponse {}, + #[snafu(display("no calls succeeded: [{}]", errors.iter().map(|e| e.to_string()).collect::>().join("")))] + Staggered { errors: Vec }, +} + +// TODO(matheus23): Remove this, or remove DnsError::Staggered, but don't keep both. +mod staggered { + use snafu::{Backtrace, GenerateImplicitData, Snafu}; + + use super::LookupError; + + /// Error returned when an input value is too long for [`crate::node_info::UserData`]. 
+ #[allow(missing_docs)] + #[derive(Debug, Snafu)] + #[snafu(display("no calls succeeded: [{}]", errors.iter().map(|e| e.to_string()).collect::>().join("")))] + pub struct Error { + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + errors: Vec, + } + + impl Error { + pub(crate) fn new(errors: Vec) -> Self { + Self { + errors, + backtrace: GenerateImplicitData::generate(), + span_trace: n0_snafu::SpanTrace::generate(), + } + } + } +} + +pub use staggered::*; + /// The DNS resolver used throughout `iroh`. #[derive(Debug, Clone)] pub struct DnsResolver(TokioResolver); @@ -80,7 +143,11 @@ impl DnsResolver { } /// Lookup a TXT record. - pub async fn lookup_txt(&self, host: impl ToString, timeout: Duration) -> Result { + pub async fn lookup_txt( + &self, + host: impl ToString, + timeout: Duration, + ) -> Result { let host = host.to_string(); let res = time::timeout(timeout, self.0.txt_lookup(host)).await??; Ok(TxtLookup(res)) @@ -91,7 +158,7 @@ impl DnsResolver { &self, host: impl ToString, timeout: Duration, - ) -> Result> { + ) -> Result, DnsError> { let host = host.to_string(); let addrs = time::timeout(timeout, self.0.ipv4_lookup(host)).await??; Ok(addrs.into_iter().map(|ip| IpAddr::V4(ip.0))) @@ -102,7 +169,7 @@ impl DnsResolver { &self, host: impl ToString, timeout: Duration, - ) -> Result> { + ) -> Result, DnsError> { let host = host.to_string(); let addrs = time::timeout(timeout, self.0.ipv6_lookup(host)).await??; Ok(addrs.into_iter().map(|ip| IpAddr::V6(ip.0))) @@ -117,7 +184,7 @@ impl DnsResolver { &self, host: impl ToString, timeout: Duration, - ) -> Result> { + ) -> Result, DnsError> { let host = host.to_string(); let res = tokio::join!( self.lookup_ipv4(host.clone(), timeout), @@ -128,9 +195,11 @@ impl DnsResolver { (Ok(ipv4), Ok(ipv6)) => Ok(LookupIter::Both(ipv4.chain(ipv6))), (Ok(ipv4), Err(_)) => Ok(LookupIter::Ipv4(ipv4)), (Err(_), Ok(ipv6)) => Ok(LookupIter::Ipv6(ipv6)), - (Err(ipv4_err), Err(ipv6_err)) => { - bail!("Ipv4: 
{:?}, Ipv6: {:?}", ipv4_err, ipv6_err) + (Err(ipv4_err), Err(ipv6_err)) => Err(ResolveBothSnafu { + ipv4: Box::new(ipv4_err), + ipv6: Box::new(ipv6_err), } + .build()), } } @@ -140,8 +209,8 @@ impl DnsResolver { url: &Url, prefer_ipv6: bool, timeout: Duration, - ) -> Result { - let host = url.host().context("Invalid URL")?; + ) -> Result { + let host = url.host().context(MissingHostSnafu)?; match host { url::Host::Domain(domain) => { // Need to do a DNS lookup @@ -151,13 +220,21 @@ impl DnsResolver { ); let (v4, v6) = match lookup { (Err(ipv4_err), Err(ipv6_err)) => { - bail!("Ipv4: {ipv4_err:?}, Ipv6: {ipv6_err:?}"); + return Err(ResolveBothSnafu { + ipv4: Box::new(ipv4_err), + ipv6: Box::new(ipv6_err), + } + .build()); } (Err(_), Ok(mut v6)) => (None, v6.next()), (Ok(mut v4), Err(_)) => (v4.next(), None), (Ok(mut v4), Ok(mut v6)) => (v4.next(), v6.next()), }; - if prefer_ipv6 { v6.or(v4) } else { v4.or(v6) }.context("No response") + if prefer_ipv6 { + v6.or(v4).context(NoResponseSnafu) + } else { + v4.or(v6).context(NoResponseSnafu) + } } url::Host::Ipv4(ip) => Ok(IpAddr::V4(ip)), url::Host::Ipv6(ip) => Ok(IpAddr::V6(ip)), @@ -175,10 +252,12 @@ impl DnsResolver { host: impl ToString, timeout: Duration, delays_ms: &[u64], - ) -> Result> { + ) -> Result, DnsError> { let host = host.to_string(); let f = || self.lookup_ipv4(host.clone(), timeout); - stagger_call(f, delays_ms).await + stagger_call(f, delays_ms) + .await + .map_err(|errors| StaggeredSnafu { errors }.build()) } /// Perform an ipv6 lookup with a timeout in a staggered fashion. @@ -192,10 +271,12 @@ impl DnsResolver { host: impl ToString, timeout: Duration, delays_ms: &[u64], - ) -> Result> { + ) -> Result, DnsError> { let host = host.to_string(); let f = || self.lookup_ipv6(host.clone(), timeout); - stagger_call(f, delays_ms).await + stagger_call(f, delays_ms) + .await + .map_err(|errors| StaggeredSnafu { errors }.build()) } /// Race an ipv4 and ipv6 lookup with a timeout in a staggered fashion. 
@@ -210,17 +291,23 @@ impl DnsResolver { host: impl ToString, timeout: Duration, delays_ms: &[u64], - ) -> Result> { + ) -> Result, DnsError> { let host = host.to_string(); let f = || self.lookup_ipv4_ipv6(host.clone(), timeout); - stagger_call(f, delays_ms).await + stagger_call(f, delays_ms) + .await + .map_err(|errors| StaggeredSnafu { errors }.build()) } /// Looks up node info by [`NodeId`] and origin domain name. /// /// To lookup nodes that published their node info to the DNS servers run by n0, /// pass [`N0_DNS_NODE_ORIGIN_PROD`] as `origin`. - pub async fn lookup_node_by_id(&self, node_id: &NodeId, origin: &str) -> Result { + pub async fn lookup_node_by_id( + &self, + node_id: &NodeId, + origin: &str, + ) -> Result { let attrs = crate::node_info::TxtAttrs::::lookup_by_id( self, node_id, origin, ) @@ -230,7 +317,7 @@ impl DnsResolver { } /// Looks up node info by DNS name. - pub async fn lookup_node_by_domain_name(&self, name: &str) -> Result { + pub async fn lookup_node_by_domain_name(&self, name: &str) -> Result { let attrs = crate::node_info::TxtAttrs::::lookup_by_name(self, name) .await?; @@ -248,9 +335,11 @@ impl DnsResolver { &self, name: &str, delays_ms: &[u64], - ) -> Result { + ) -> Result { let f = || self.lookup_node_by_domain_name(name); - stagger_call(f, delays_ms).await + stagger_call(f, delays_ms) + .await + .map_err(staggered::Error::new) } /// Looks up node info by [`NodeId`] and origin domain name. @@ -264,9 +353,11 @@ impl DnsResolver { node_id: &NodeId, origin: &str, delays_ms: &[u64], - ) -> Result { + ) -> Result { let f = || self.lookup_node_by_id(node_id, origin); - stagger_call(f, delays_ms).await + stagger_call(f, delays_ms) + .await + .map_err(staggered::Error::new) } } @@ -356,10 +447,10 @@ impl, B: Iterator> Iterator for Lookup /// /// The first call is performed immediately. The first call to succeed generates an Ok result /// ignoring any previous error. If all calls fail, an error summarizing all errors is returned. 
-async fn stagger_call Fut, Fut: Future>>( +async fn stagger_call Fut, Fut: Future>>( f: F, delays_ms: &[u64], -) -> Result { +) -> Result> { let mut calls = n0_future::FuturesUnorderedBounded::new(delays_ms.len() + 1); // NOTE: we add the 0 delay here to have a uniform set of futures. This is more performant than // using alternatives that allow futures of different types. @@ -381,13 +472,7 @@ async fn stagger_call Fut, Fut: Future>>( } } - bail!( - "no calls succeed: [ {}]", - errors.into_iter().fold(String::new(), |mut summary, e| { - write!(summary, "{e} ").expect("infallible"); - summary - }) - ) + Err(errors) } #[cfg(test)] @@ -407,7 +492,7 @@ pub(crate) mod tests { let r_pos = DONE_CALL.fetch_add(1, std::sync::atomic::Ordering::Relaxed); async move { tracing::info!(r_pos, "call"); - CALL_RESULTS[r_pos].map_err(|e| anyhow::anyhow!("{e}")) + CALL_RESULTS[r_pos].map_err(|_| InvalidResponseSnafu.build()) } }; diff --git a/iroh-relay/src/main.rs b/iroh-relay/src/main.rs index 80ff1401f02..3aa4da21546 100644 --- a/iroh-relay/src/main.rs +++ b/iroh-relay/src/main.rs @@ -5,11 +5,11 @@ use std::{ net::{Ipv6Addr, SocketAddr}, + num::NonZeroU32, path::{Path, PathBuf}, sync::Arc, }; -use anyhow::{anyhow, bail, Context as _, Result}; use clap::Parser; use http::StatusCode; use iroh_base::NodeId; @@ -21,7 +21,9 @@ use iroh_relay::{ server::{self as relay, ClientRateLimit, QuicConfig}, }; use n0_future::FutureExt; +use n0_snafu::{Error, Result, ResultExt}; use serde::{Deserialize, Serialize}; +use snafu::whatever; use tokio_rustls_acme::{caches::DirCache, AcmeConfig}; use tracing::{debug, warn}; use tracing_subscriber::{prelude::*, EnvFilter}; @@ -68,7 +70,7 @@ fn load_certs( let mut reader = std::io::BufReader::new(certfile); let certs: Result, std::io::Error> = rustls_pemfile::certs(&mut reader).collect(); - let certs = certs?; + let certs = certs.context("reading cert")?; Ok(certs) } @@ -97,7 +99,7 @@ fn load_secret_key( } } - bail!( + whatever!( "no keys found in {} 
(encrypted keys not supported)", filename.display() ); @@ -299,10 +301,10 @@ async fn http_access_check_inner( } Ok(res) if res.status() == StatusCode::OK => match res.text().await { Ok(text) if text == "true" => Ok(()), - Ok(_) => Err(anyhow!("Invalid response text (must be 'true')")), + Ok(_) => whatever!("Invalid response text (must be 'true')"), Err(err) => Err(err).context("Failed to read response"), }, - Ok(res) => Err(anyhow!("Received invalid status code ({})", res.status())), + Ok(res) => whatever!("Received invalid status code ({})", res.status()), } } @@ -518,7 +520,7 @@ impl Config { async fn read_from_file(path: impl AsRef) -> Result { if !path.as_ref().is_file() { - bail!("config-path must be a file"); + whatever!("config-path must be a file"); } let config_ser = tokio::fs::read_to_string(&path) .await @@ -537,7 +539,7 @@ async fn main() -> Result<()> { let cli = Cli::parse(); let mut cfg = Config::load(&cli).await?; if cfg.enable_quic_addr_discovery && cfg.tls.is_none() { - bail!("TLS must be configured in order to spawn a QUIC endpoint"); + whatever!("TLS must be configured in order to spawn a QUIC endpoint"); } if cli.dev { // When in `--dev` mode, do not use https, even when tls is configured. 
@@ -549,7 +551,7 @@ async fn main() -> Result<()> { } } if cfg.tls.is_none() && cfg.enable_quic_addr_discovery { - bail!("If QUIC address discovery is enabled, TLS must also be configured"); + whatever!("If QUIC address discovery is enabled, TLS must also be configured"); }; let relay_config = build_relay_config(cfg).await?; debug!("{relay_config:#?}"); @@ -562,7 +564,8 @@ async fn main() -> Result<()> { _ = relay.task_handle() => (), } - relay.shutdown().await + relay.shutdown().await?; + Ok(()) } async fn maybe_load_tls( @@ -585,10 +588,13 @@ async fn maybe_load_tls( let (private_key, certs) = tokio::task::spawn_blocking(move || { let key = load_secret_key(key_path)?; let certs = load_certs(cert_path)?; - anyhow::Ok((key, certs)) + Ok::<_, Error>((key, certs)) }) - .await??; - let server_config = server_config.with_single_cert(certs.clone(), private_key)?; + .await + .context("join")??; + let server_config = server_config + .with_single_cert(certs.clone(), private_key) + .context("tls config")?; (relay::CertConfig::Manual { certs }, server_config) } CertMode::LetsEncrypt => { @@ -638,7 +644,11 @@ async fn maybe_load_tls( certs_reader, }; - let resolver = Arc::new(relay::ReloadingResolver::init(loader, interval).await?); + let resolver = Arc::new( + relay::ReloadingResolver::init(loader, interval) + .await + .context("cert loading")?, + ); let server_config = server_config.with_cert_resolver(resolver); (relay::CertConfig::Reloading, server_config) } @@ -667,7 +677,7 @@ async fn build_relay_config(cfg: Config) -> Result Result { if rx.bytes_per_second.is_none() && rx.max_burst_bytes.is_some() { - bail!("bytes_per_seconds must be specified to enable the rate-limiter"); + whatever!("bytes_per_seconds must be specified to enable the rate-limiter"); } match rx.bytes_per_second { Some(bps) => Some(ClientRateLimit { - bytes_per_second: bps - .try_into() + bytes_per_second: TryInto::::try_into(bps) .context("bytes_per_second must be non-zero u32")?, max_burst_bytes: rx 
.max_burst_bytes .map(|v| { - v.try_into().context("max_burst_bytes must be non-zero u32") + TryInto::::try_into(v) + .context("max_burst_bytes must be non-zero u32") }) .transpose()?, }), @@ -729,14 +739,14 @@ mod tests { use std::num::NonZeroU32; use iroh_base::SecretKey; + use n0_snafu::Result; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; - use testresult::TestResult; use super::*; #[tokio::test] - async fn test_rate_limit_config() -> TestResult { + async fn test_rate_limit_config() -> Result { let config = " [limits.client.rx] bytes_per_second = 400 @@ -759,7 +769,7 @@ mod tests { } #[tokio::test] - async fn test_rate_limit_default() -> TestResult { + async fn test_rate_limit_default() -> Result { let config = Config::from_str("")?; let relay_config = build_relay_config(config).await?; @@ -770,7 +780,7 @@ mod tests { } #[tokio::test] - async fn test_access_config() -> TestResult { + async fn test_access_config() -> Result { let config = " access = \"everyone\" "; diff --git a/iroh-relay/src/node_info.rs b/iroh-relay/src/node_info.rs index 47e082fadfb..e4cc9aa2cec 100644 --- a/iroh-relay/src/node_info.rs +++ b/iroh-relay/src/node_info.rs @@ -37,23 +37,63 @@ use std::{ fmt::{self, Display}, hash::Hash, net::SocketAddr, - str::FromStr, + str::{FromStr, Utf8Error}, }; -use anyhow::{anyhow, Result}; #[cfg(not(wasm_browser))] use hickory_resolver::{proto::ProtoError, Name}; -use iroh_base::{NodeAddr, NodeId, RelayUrl, SecretKey}; +use iroh_base::{NodeAddr, NodeId, RelayUrl, SecretKey, SignatureError}; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, ResultExt, Snafu}; #[cfg(not(wasm_browser))] use tracing::warn; use url::Url; #[cfg(not(wasm_browser))] -use crate::{defaults::timeouts::DNS_TIMEOUT, dns::DnsResolver}; +use crate::{ + defaults::timeouts::DNS_TIMEOUT, + dns::{DnsError, DnsResolver}, +}; /// The DNS name for the iroh TXT record. 
pub const IROH_TXT_NAME: &str = "_iroh"; +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +#[snafu(visibility(pub(crate)))] +pub enum EncodingError { + #[snafu(transparent)] + FailedBuildingPacket { + source: pkarr::errors::SignedPacketBuildError, + }, + #[snafu(display("invalid TXT entry"))] + InvalidTxtEntry { source: pkarr::dns::SimpleDnsError }, +} + +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +#[snafu(visibility(pub(crate)))] +pub enum DecodingError { + #[snafu(display("node id was not encoded in valid z32"))] + InvalidEncodingZ32 { source: z32::Z32Error }, + #[snafu(display("length must be 32 bytes, but got {len} byte(s)"))] + InvalidLength { len: usize }, + #[snafu(display("node id is not a valid public key"))] + InvalidSignature { source: SignatureError }, +} + /// Extension methods for [`NodeId`] to encode to and decode from [`z32`], /// which is the encoding used in [`pkarr`] domain names. pub trait NodeIdExt { @@ -65,7 +105,7 @@ pub trait NodeIdExt { /// Parses a [`NodeId`] from [`z-base-32`] encoding. 
/// /// [z-base-32]: https://philzimmermann.com/docs/human-oriented-base-32-encoding.txt - fn from_z32(s: &str) -> Result; + fn from_z32(s: &str) -> Result; } impl NodeIdExt for NodeId { @@ -73,10 +113,12 @@ impl NodeIdExt for NodeId { z32::encode(self.as_bytes()) } - fn from_z32(s: &str) -> Result { - let bytes = z32::decode(s.as_bytes()).map_err(|_| anyhow!("invalid z32"))?; - let bytes: &[u8; 32] = &bytes.try_into().map_err(|_| anyhow!("not 32 bytes long"))?; - let node_id = NodeId::from_bytes(bytes)?; + fn from_z32(s: &str) -> Result { + let bytes = z32::decode(s.as_bytes()).context(InvalidEncodingZ32Snafu)?; + let bytes: &[u8; 32] = &bytes + .try_into() + .map_err(|_| InvalidLengthSnafu { len: s.len() }.build())?; + let node_id = NodeId::from_bytes(bytes).context(InvalidSignatureSnafu)?; Ok(node_id) } } @@ -194,19 +236,20 @@ impl UserData { } /// Error returned when an input value is too long for [`UserData`]. -#[derive(Debug, thiserror::Error)] -#[error("User-defined data exceeds max length")] -pub struct MaxLengthExceededError; +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +pub struct MaxLengthExceededError { + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +} impl TryFrom for UserData { type Error = MaxLengthExceededError; fn try_from(value: String) -> Result { - if value.len() > Self::MAX_LENGTH { - Err(MaxLengthExceededError) - } else { - Ok(Self(value)) - } + snafu::ensure!(value.len() <= Self::MAX_LENGTH, MaxLengthExceededSnafu); + Ok(Self(value)) } } @@ -214,11 +257,8 @@ impl FromStr for UserData { type Err = MaxLengthExceededError; fn from_str(s: &str) -> std::result::Result { - if s.len() > Self::MAX_LENGTH { - Err(MaxLengthExceededError) - } else { - Ok(Self(s.to_string())) - } + snafu::ensure!(s.len() <= Self::MAX_LENGTH, MaxLengthExceededSnafu); + Ok(Self(s.to_string())) } } @@ -349,13 +389,13 @@ impl NodeInfo { #[cfg(not(wasm_browser))] /// Parses a [`NodeInfo`] from a TXT records lookup. 
- pub fn from_txt_lookup(lookup: crate::dns::TxtLookup) -> Result { + pub fn from_txt_lookup(lookup: crate::dns::TxtLookup) -> Result { let attrs = TxtAttrs::from_txt_lookup(lookup)?; Ok(attrs.into()) } /// Parses a [`NodeInfo`] from a [`pkarr::SignedPacket`]. - pub fn from_pkarr_signed_packet(packet: &pkarr::SignedPacket) -> Result { + pub fn from_pkarr_signed_packet(packet: &pkarr::SignedPacket) -> Result { let attrs = TxtAttrs::from_pkarr_signed_packet(packet)?; Ok(attrs.into()) } @@ -367,7 +407,7 @@ impl NodeInfo { &self, secret_key: &SecretKey, ttl: u32, - ) -> Result { + ) -> Result { self.to_attrs().to_pkarr_signed_packet(secret_key, ttl) } @@ -377,6 +417,59 @@ impl NodeInfo { } } +#[cfg(not(wasm_browser))] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +#[snafu(visibility(pub(crate)))] +pub enum LookupError { + #[snafu(display("Malformed txt from lookup"))] + ParseError { + #[snafu(source(from(ParseError, Box::new)))] + source: Box, + }, + #[snafu(display("Failed to resolve TXT record"))] + LookupFailed { + #[snafu(source(from(DnsError, Box::new)))] + source: Box, + }, + #[cfg(not(wasm_browser))] + #[snafu(display("Name is not a valid TXT label"))] + InvalidLabel { source: ProtoError }, +} + +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +#[snafu(visibility(pub(crate)))] +pub enum ParseError { + #[snafu(display("Expected format `key=value`, received `{s}`"))] + UnexpectedFormat { s: String }, + #[snafu(display("Could not convert key to Attr"))] + AttrFromString { key: String }, + #[snafu(display("Expected 2 labels, received {num_labels}"))] + NumLabels { num_labels: usize }, + #[snafu(display("Could not parse labels"))] + Utf8 { source: Utf8Error }, + #[snafu(display("Record is not an `iroh` record, expected `_iroh`, 
got `{label}`"))] + NotAnIrohRecord { label: String }, + #[snafu(transparent)] + DecodingError { + #[snafu(source(from(DecodingError, Box::new)))] + source: Box, + }, +} + impl std::ops::Deref for NodeInfo { type Target = NodeData; fn deref(&self) -> &Self::Target { @@ -396,18 +489,25 @@ impl std::ops::DerefMut for NodeInfo { /// [`IROH_TXT_NAME`] and the second label to be a z32 encoded [`NodeId`]. Ignores /// subsequent labels. #[cfg(not(wasm_browser))] -fn node_id_from_hickory_name(name: &hickory_resolver::proto::rr::Name) -> Option { +fn node_id_from_hickory_name( + name: &hickory_resolver::proto::rr::Name, +) -> Result { if name.num_labels() < 2 { - return None; + return Err(NumLabelsSnafu { + num_labels: name.num_labels(), + } + .build()); } let mut labels = name.iter(); - let label = std::str::from_utf8(labels.next().expect("num_labels checked")).ok()?; + let label = + std::str::from_utf8(labels.next().expect("num_labels checked")).context(Utf8Snafu)?; if label != IROH_TXT_NAME { - return None; + return Err(NotAnIrohRecordSnafu { label }.build()); } - let label = std::str::from_utf8(labels.next().expect("num_labels checked")).ok()?; - let node_id = NodeId::from_z32(label).ok()?; - Some(node_id) + let label = + std::str::from_utf8(labels.next().expect("num_labels checked")).context(Utf8Snafu)?; + let node_id = NodeId::from_z32(label)?; + Ok(node_id) } /// The attributes supported by iroh for [`IROH_TXT_NAME`] DNS resource records. 
@@ -467,26 +567,27 @@ impl TxtAttrs { pub(crate) fn from_strings( node_id: NodeId, strings: impl Iterator, - ) -> Result { + ) -> Result { let mut attrs: BTreeMap> = BTreeMap::new(); for s in strings { let mut parts = s.split('='); let (Some(key), Some(value)) = (parts.next(), parts.next()) else { - continue; - }; - let Ok(attr) = T::from_str(key) else { - continue; + return Err(UnexpectedFormatSnafu { s }.build()); }; + let attr = T::from_str(key).map_err(|_| AttrFromStringSnafu { key }.build())?; attrs.entry(attr).or_default().push(value.to_string()); } Ok(Self { attrs, node_id }) } #[cfg(not(wasm_browser))] - async fn lookup(resolver: &DnsResolver, name: Name) -> Result { + async fn lookup(resolver: &DnsResolver, name: Name) -> Result { let name = ensure_iroh_txt_label(name)?; - let lookup = resolver.lookup_txt(name, DNS_TIMEOUT).await?; - let attrs = Self::from_txt_lookup(lookup)?; + let lookup = resolver + .lookup_txt(name, DNS_TIMEOUT) + .await + .context(LookupFailedSnafu)?; + let attrs = Self::from_txt_lookup(lookup).context(ParseSnafu)?; Ok(attrs) } @@ -496,15 +597,18 @@ impl TxtAttrs { resolver: &DnsResolver, node_id: &NodeId, origin: &str, - ) -> Result { + ) -> Result { let name = node_domain(node_id, origin)?; TxtAttrs::lookup(resolver, name).await } /// Looks up attributes by DNS name. #[cfg(not(wasm_browser))] - pub(crate) async fn lookup_by_name(resolver: &DnsResolver, name: &str) -> Result { - let name = Name::from_str(name)?; + pub(crate) async fn lookup_by_name( + resolver: &DnsResolver, + name: &str, + ) -> Result { + let name = Name::from_str(name).context(InvalidLabelSnafu)?; TxtAttrs::lookup(resolver, name).await } @@ -519,7 +623,9 @@ impl TxtAttrs { } /// Parses a [`pkarr::SignedPacket`]. 
- pub(crate) fn from_pkarr_signed_packet(packet: &pkarr::SignedPacket) -> Result { + pub(crate) fn from_pkarr_signed_packet( + packet: &pkarr::SignedPacket, + ) -> Result { use pkarr::dns::{ rdata::RData, {self}, @@ -527,7 +633,7 @@ impl TxtAttrs { let pubkey = packet.public_key(); let pubkey_z32 = pubkey.to_z32(); let node_id = NodeId::from(*pubkey.verifying_key()); - let zone = dns::Name::new(&pubkey_z32)?; + let zone = dns::Name::new(&pubkey_z32).expect("z32 encoding is valid"); let txt_data = packet .all_resource_records() .filter_map(|rr| match &rr.rdata { @@ -544,14 +650,13 @@ impl TxtAttrs { /// Parses a TXT records lookup. #[cfg(not(wasm_browser))] - pub(crate) fn from_txt_lookup(lookup: crate::dns::TxtLookup) -> Result { - let queried_node_id = node_id_from_hickory_name(lookup.0.query().name()) - .ok_or_else(|| anyhow!("invalid DNS answer: not a query for _iroh.z32encodedpubkey"))?; + pub(crate) fn from_txt_lookup(lookup: crate::dns::TxtLookup) -> Result { + let queried_node_id = node_id_from_hickory_name(lookup.0.query().name())?; let strings = lookup.0.as_lookup().record_iter().filter_map(|record| { match node_id_from_hickory_name(record.name()) { // Filter out only TXT record answers that match the node_id we searched for. 
- Some(n) if n == queried_node_id => match record.data().as_txt() { + Ok(n) if n == queried_node_id => match record.data().as_txt() { Some(txt) => Some(txt.to_string()), None => { warn!( @@ -562,7 +667,7 @@ impl TxtAttrs { None } }, - Some(answered_node_id) => { + Ok(answered_node_id) => { warn!( ?queried_node_id, ?answered_node_id, @@ -570,10 +675,11 @@ impl TxtAttrs { ); None } - None => { + Err(e) => { warn!( ?queried_node_id, name = ?record.name(), + err = ?e, "unexpected answer record name for DNS query" ); None @@ -597,15 +703,15 @@ impl TxtAttrs { &self, secret_key: &SecretKey, ttl: u32, - ) -> Result { + ) -> Result { use pkarr::dns::{self, rdata}; let keypair = pkarr::Keypair::from_secret_key(&secret_key.to_bytes()); - let name = dns::Name::new(IROH_TXT_NAME)?; + let name = dns::Name::new(IROH_TXT_NAME).expect("constant"); let mut builder = pkarr::SignedPacket::builder(); for s in self.to_txt_strings() { let mut txt = rdata::TXT::new(); - txt.add_string(&s)?; + txt.add_string(&s).context(InvalidTxtEntrySnafu)?; builder = builder.txt(name.clone(), txt.into_owned(), ttl); } let signed_packet = builder.build(&keypair)?; @@ -614,18 +720,18 @@ impl TxtAttrs { } #[cfg(not(wasm_browser))] -fn ensure_iroh_txt_label(name: Name) -> Result { +fn ensure_iroh_txt_label(name: Name) -> Result { if name.iter().next() == Some(IROH_TXT_NAME.as_bytes()) { Ok(name) } else { - Name::parse(IROH_TXT_NAME, Some(&name)) + Name::parse(IROH_TXT_NAME, Some(&name)).context(InvalidLabelSnafu) } } #[cfg(not(wasm_browser))] -fn node_domain(node_id: &NodeId, origin: &str) -> Result { +fn node_domain(node_id: &NodeId, origin: &str) -> Result { let domain = format!("{}.{}", NodeId::to_z32(node_id), origin); - let domain = Name::from_str(&domain)?; + let domain = Name::from_str(&domain).context(InvalidLabelSnafu)?; Ok(domain) } @@ -645,7 +751,7 @@ mod tests { Name, }; use iroh_base::{NodeId, SecretKey}; - use testresult::TestResult; + use n0_snafu::{Result, ResultExt}; use super::{NodeData, 
NodeIdExt, NodeInfo}; @@ -687,10 +793,11 @@ mod tests { /// The reason was that only the first address was parsed (e.g. 192.168.96.145 in /// this example), which could be a local, unreachable address. #[test] - fn test_from_hickory_lookup() -> TestResult { + fn test_from_hickory_lookup() -> Result { let name = Name::from_utf8( "_iroh.dgjpkxyn3zyrk3zfads5duwdgbqpkwbjxfj4yt7rezidr3fijccy.dns.iroh.link.", - )?; + ) + .context("dns name")?; let query = Query::query(name.clone(), RecordType::TXT); let records = [ Record::from_rdata( @@ -714,7 +821,8 @@ mod tests { "a55f26132e5e43de834d534332f66a20d480c3e50a13a312a071adea6569981e" )? .to_z32() - ))?, + )) + .context("name")?, 30, RData::TXT(TXT::new(vec![ "relay=https://euw1-1.relay.iroh.network./".to_string() @@ -722,7 +830,7 @@ mod tests { ), // Test a record with a completely different name Record::from_rdata( - Name::from_utf8("dns.iroh.link.")?, + Name::from_utf8("dns.iroh.link.").context("name")?, 30, RData::TXT(TXT::new(vec![ "relay=https://euw1-1.relay.iroh.network./".to_string() @@ -746,8 +854,8 @@ mod tests { )?) .with_relay_url(Some("https://euw1-1.relay.iroh.network./".parse()?)) .with_direct_addresses(BTreeSet::from([ - "192.168.96.145:60165".parse()?, - "213.208.157.87:60165".parse()?, + "192.168.96.145:60165".parse().unwrap(), + "213.208.157.87:60165".parse().unwrap(), ])); assert_eq!(node_info, expected_node_info); diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 76668e77714..6caa42d91c3 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -12,18 +12,19 @@ //! * clients sends `FrameType::SendPacket` //! 
* server then sends `FrameType::RecvPacket` to recipient -use anyhow::{bail, ensure}; use bytes::{BufMut, Bytes}; -use iroh_base::{PublicKey, SecretKey, Signature}; +use iroh_base::{PublicKey, SecretKey, Signature, SignatureError}; #[cfg(feature = "server")] use n0_future::time::Duration; -use n0_future::{Sink, SinkExt}; +use n0_future::{time, Sink, SinkExt}; #[cfg(any(test, feature = "server"))] use n0_future::{Stream, StreamExt}; +use nested_enum_utils::common_fields; use postcard::experimental::max_size::MaxSize; use serde::{Deserialize, Serialize}; +use snafu::{Backtrace, Snafu}; -use crate::{client::conn::ConnSendError, KeyCache}; +use crate::{client::conn::SendError as ConnSendError, KeyCache}; /// The maximum size of a packet sent over relay. /// (This only includes the data bytes visible to magicsock, not @@ -74,7 +75,7 @@ const NOT_PREFERRED: u8 = 0u8; /// length of the remaining frame (not including the initial 5 bytes) #[derive(Debug, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::FromPrimitive, Clone, Copy)] #[repr(u8)] -pub(crate) enum FrameType { +pub enum FrameType { /// magic + 32b pub key + 24B nonce + bytes ClientInfo = 2, /// 32B dest pub key + packet bytes @@ -110,6 +111,7 @@ pub(crate) enum FrameType { /// /// Handled on the `[relay::Client]`, but currently never sent on the `[relay::Server]` Restarting = 15, + /// Unknown frame type #[num_enum(default)] Unknown = 255, } @@ -126,6 +128,56 @@ pub(crate) struct ClientInfo { pub(crate) version: usize, } +/// Protocol send errors. +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum SendError { + #[snafu(transparent)] + Io { source: std::io::Error }, + #[snafu(transparent)] + Timeout { source: time::Elapsed }, + #[snafu(transparent)] + ConnSend { source: ConnSendError }, + #[snafu(transparent)] + SerDe { source: postcard::Error }, +} + +/// Protocol send errors. 
+#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum RecvError { + #[snafu(transparent)] + Io { source: std::io::Error }, + #[snafu(display("unexpected frame: got {got}, expected {expected}"))] + UnexpectedFrame { got: FrameType, expected: FrameType }, + #[snafu(display("Frame is too large, has {frame_len} bytes"))] + FrameTooLarge { frame_len: usize }, + #[snafu(transparent)] + Timeout { source: time::Elapsed }, + #[snafu(transparent)] + SerDe { source: postcard::Error }, + #[snafu(transparent)] + InvalidSignature { source: SignatureError }, + #[snafu(display("Invalid frame encoding"))] + InvalidFrame {}, + #[snafu(display("Invalid frame type: {frame_type}"))] + InvalidFrameType { frame_type: FrameType }, + #[snafu(display("Too few bytes"))] + TooSmall {}, +} + /// Writes complete frame, errors if it is unable to write within the given `timeout`. /// Ignores the timeout if `None` /// @@ -135,7 +187,7 @@ pub(crate) async fn write_frame + Unpin>( mut writer: S, frame: Frame, timeout: Option, -) -> anyhow::Result<()> { +) -> Result<(), SendError> { if let Some(duration) = timeout { tokio::time::timeout(duration, writer.send(frame)).await??; } else { @@ -153,7 +205,7 @@ pub(crate) async fn send_client_key + Unpi mut writer: S, client_secret_key: &SecretKey, client_info: &ClientInfo, -) -> anyhow::Result<()> { +) -> Result<(), SendError> { let msg = postcard::to_stdvec(client_info)?; let signature = client_secret_key.sign(&msg); @@ -171,10 +223,12 @@ pub(crate) async fn send_client_key + Unpi /// Reads the `FrameType::ClientInfo` frame from the client (its proof of identity) /// upon it's initial connection. 
#[cfg(any(test, feature = "server"))] -pub(crate) async fn recv_client_key> + Unpin>( +pub(crate) async fn recv_client_key> + Unpin>( stream: S, -) -> anyhow::Result<(PublicKey, ClientInfo)> { - use anyhow::Context; +) -> Result<(PublicKey, ClientInfo), E> +where + E: From, +{ // the client is untrusted at this point, limit the input size even smaller than our usual // maximum frame size, and give a timeout @@ -184,8 +238,7 @@ pub(crate) async fn recv_client_key> + Un recv_frame(FrameType::ClientInfo, stream), ) .await - .context("recv_frame timeout")? - .context("recv_frame")?; + .map_err(RecvError::from)??; if let Frame::ClientInfo { client_public_key, @@ -195,11 +248,17 @@ pub(crate) async fn recv_client_key> + Un { client_public_key .verify(&message, &signature) - .context("invalid signature")?; - let info: ClientInfo = postcard::from_bytes(&message).context("deserialization")?; + .map_err(RecvError::from)?; + + let info: ClientInfo = postcard::from_bytes(&message).map_err(RecvError::from)?; Ok((client_public_key, info)) } else { - anyhow::bail!("expected FrameType::ClientInfo"); + Err(UnexpectedFrameSnafu { + got: buf.typ(), + expected: FrameType::ClientInfo, + } + .build() + .into()) } } @@ -314,9 +373,10 @@ impl Frame { /// Tries to decode a frame received over websockets. /// /// Specifically, bytes received from a binary websocket message frame. 
- pub(crate) fn decode_from_ws_msg(bytes: Bytes, cache: &KeyCache) -> anyhow::Result { + #[allow(clippy::result_large_err)] + pub(crate) fn decode_from_ws_msg(bytes: Bytes, cache: &KeyCache) -> Result { if bytes.is_empty() { - bail!("error parsing relay::codec::Frame: too few bytes (0)"); + return Err(TooSmallSnafu.build()); } let typ = FrameType::from(bytes[0]); let frame = Self::from_bytes(typ, bytes.slice(1..), cache)?; @@ -384,18 +444,20 @@ impl Frame { } } - fn from_bytes(frame_type: FrameType, content: Bytes, cache: &KeyCache) -> anyhow::Result { + #[allow(clippy::result_large_err)] + fn from_bytes( + frame_type: FrameType, + content: Bytes, + cache: &KeyCache, + ) -> Result { let res = match frame_type { FrameType::ClientInfo => { - ensure!( - content.len() >= PublicKey::LENGTH + Signature::BYTE_SIZE + MAGIC.len(), - "invalid client info frame length: {}", - content.len() - ); - ensure!( - &content[..MAGIC.len()] == MAGIC.as_bytes(), - "invalid client info frame magic" - ); + if content.len() < PublicKey::LENGTH + Signature::BYTE_SIZE + MAGIC.len() { + return Err(InvalidFrameSnafu.build()); + } + if &content[..MAGIC.len()] != MAGIC.as_bytes() { + return Err(InvalidFrameSnafu.build()); + } let start = MAGIC.len(); let client_public_key = @@ -412,84 +474,94 @@ impl Frame { } } FrameType::SendPacket => { - ensure!( - content.len() >= PublicKey::LENGTH, - "invalid send packet frame length: {}", - content.len() - ); - let packet_len = content.len() - PublicKey::LENGTH; - ensure!( - packet_len <= MAX_PACKET_SIZE, - "data packet longer ({packet_len}) than max of {MAX_PACKET_SIZE}" - ); + if content.len() < PublicKey::LENGTH { + return Err(InvalidFrameSnafu.build()); + } + let frame_len = content.len() - PublicKey::LENGTH; + if frame_len > MAX_PACKET_SIZE { + return Err(FrameTooLargeSnafu { frame_len }.build()); + } + let dst_key = cache.key_from_slice(&content[..PublicKey::LENGTH])?; let packet = content.slice(PublicKey::LENGTH..); Self::SendPacket { dst_key, 
packet } } FrameType::RecvPacket => { - ensure!( - content.len() >= PublicKey::LENGTH, - "invalid recv packet frame length: {}", - content.len() - ); - let packet_len = content.len() - PublicKey::LENGTH; - ensure!( - packet_len <= MAX_PACKET_SIZE, - "data packet longer ({packet_len}) than max of {MAX_PACKET_SIZE}" - ); + if content.len() < PublicKey::LENGTH { + return Err(InvalidFrameSnafu.build()); + } + + let frame_len = content.len() - PublicKey::LENGTH; + if frame_len > MAX_PACKET_SIZE { + return Err(FrameTooLargeSnafu { frame_len }.build()); + } + let src_key = cache.key_from_slice(&content[..PublicKey::LENGTH])?; let content = content.slice(PublicKey::LENGTH..); Self::RecvPacket { src_key, content } } FrameType::KeepAlive => { - anyhow::ensure!(content.is_empty(), "invalid keep alive frame length"); + if !content.is_empty() { + return Err(InvalidFrameSnafu.build()); + } Self::KeepAlive } FrameType::NotePreferred => { - anyhow::ensure!(content.len() == 1, "invalid note preferred frame length"); + if content.len() != 1 { + return Err(InvalidFrameSnafu.build()); + } let preferred = match content[0] { PREFERRED => true, NOT_PREFERRED => false, - _ => anyhow::bail!("invalid note preferred frame content"), + _ => return Err(InvalidFrameSnafu.build()), }; Self::NotePreferred { preferred } } FrameType::PeerGone => { - anyhow::ensure!( - content.len() == PublicKey::LENGTH, - "invalid peer gone frame length" - ); + if content.len() != PublicKey::LENGTH { + return Err(InvalidFrameSnafu.build()); + } let peer = cache.key_from_slice(&content[..32])?; Self::NodeGone { node_id: peer } } FrameType::Ping => { - anyhow::ensure!(content.len() == 8, "invalid ping frame length"); + if content.len() != 8 { + return Err(InvalidFrameSnafu.build()); + } let mut data = [0u8; 8]; data.copy_from_slice(&content[..8]); Self::Ping { data } } FrameType::Pong => { - anyhow::ensure!(content.len() == 8, "invalid pong frame length"); + if content.len() != 8 { + return 
Err(InvalidFrameSnafu.build()); + } let mut data = [0u8; 8]; data.copy_from_slice(&content[..8]); Self::Pong { data } } FrameType::Health => Self::Health { problem: content }, FrameType::Restarting => { - ensure!( - content.len() == 4 + 4, - "invalid restarting frame length: {}", - content.len() + if content.len() != 4 + 4 { + return Err(InvalidFrameSnafu.build()); + } + let reconnect_in = u32::from_be_bytes( + content[..4] + .try_into() + .map_err(|_| InvalidFrameSnafu.build())?, + ); + let try_for = u32::from_be_bytes( + content[4..] + .try_into() + .map_err(|_| InvalidFrameSnafu.build())?, ); - let reconnect_in = u32::from_be_bytes(content[..4].try_into()?); - let try_for = u32::from_be_bytes(content[4..].try_into()?); Self::Restarting { reconnect_in, try_for, } } _ => { - anyhow::bail!("invalid frame type: {:?}", frame_type); + return Err(InvalidFrameTypeSnafu { frame_type }.build()); } }; Ok(res) @@ -513,7 +585,7 @@ mod framing { impl Decoder for RelayCodec { type Item = Frame; - type Error = anyhow::Error; + type Error = RecvError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { // Need at least 5 bytes @@ -535,7 +607,7 @@ mod framing { }; if frame_len > MAX_FRAME_SIZE { - anyhow::bail!("Frame of length {} is too large.", frame_len); + return Err(FrameTooLargeSnafu { frame_len }.build()); } if src.len() < HEADER_LEN + frame_len { @@ -582,35 +654,45 @@ mod framing { /// Receives the next frame and matches the frame type. If the correct type is found returns the content, /// otherwise an error. 
#[cfg(any(test, feature = "server"))] -pub(crate) async fn recv_frame> + Unpin>( +pub(crate) async fn recv_frame> + Unpin>( frame_type: FrameType, mut stream: S, -) -> anyhow::Result { +) -> Result +where + RecvError: Into, +{ match stream.next().await { Some(Ok(frame)) => { - ensure!( - frame_type == frame.typ(), - "expected frame {}, found {}", - frame_type, - frame.typ() - ); + if frame_type != frame.typ() { + return Err(UnexpectedFrameSnafu { + got: frame.typ(), + expected: frame_type, + } + .build() + .into()); + } Ok(frame) } Some(Err(err)) => Err(err), - None => bail!("EOF: unexpected stream end, expected frame {}", frame_type), + None => Err(RecvError::from(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "expected frame".to_string(), + )) + .into()), } } #[cfg(test)] mod tests { use data_encoding::HEXLOWER; + use n0_snafu::{Result, ResultExt}; use tokio_util::codec::{FramedRead, FramedWrite}; use super::*; #[tokio::test] #[cfg(feature = "server")] - async fn test_basic_read_write() -> anyhow::Result<()> { + async fn test_basic_read_write() -> Result { let (reader, writer) = tokio::io::duplex(1024); let mut reader = FramedRead::new(reader, RelayCodec::test()); let mut writer = FramedWrite::new(writer, RelayCodec::test()); @@ -620,7 +702,7 @@ mod tests { problem: expect_buf.to_vec().into(), }; write_frame(&mut writer, expected_frame.clone(), None).await?; - writer.flush().await?; + writer.flush().await.context("flush")?; println!("{:?}", reader); let buf = recv_frame(FrameType::Health, &mut reader).await?; assert_eq!(expect_buf.len(), buf.len()); @@ -630,7 +712,7 @@ mod tests { } #[tokio::test] - async fn test_send_recv_client_key() -> anyhow::Result<()> { + async fn test_send_recv_client_key() -> Result { let (reader, writer) = tokio::io::duplex(1024); let mut reader = FramedRead::new(reader, RelayCodec::test()); let mut writer = @@ -649,12 +731,12 @@ mod tests { } #[test] - fn test_frame_snapshot() -> anyhow::Result<()> { + fn 
test_frame_snapshot() -> Result { let client_key = SecretKey::from_bytes(&[42u8; 32]); let client_info = ClientInfo { version: PROTOCOL_VERSION, }; - let message = postcard::to_stdvec(&client_info)?; + let message = postcard::to_stdvec(&client_info).context("encode")?; let signature = client_key.sign(&message); let frames = vec![ @@ -864,7 +946,7 @@ mod proptests { codec.encode(frame.clone(), &mut buf).unwrap(); inject_error(&mut buf); let decoded = codec.decode(&mut buf); - prop_assert!(decoded.is_err()); + prop_assert!(decoded.is_err(), "{:?}", decoded); } } } diff --git a/iroh-relay/src/protos/stun.rs b/iroh-relay/src/protos/stun.rs index ebc3accbcb5..67e4d6cb02a 100644 --- a/iroh-relay/src/protos/stun.rs +++ b/iroh-relay/src/protos/stun.rs @@ -2,6 +2,8 @@ use std::net::SocketAddr; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, Snafu}; use stun_rs::{ attributes::stun::{Fingerprint, XorMappedAddress}, DecoderContextBuilder, MessageDecoderBuilder, MessageEncoderBuilder, StunMessageBuilder, @@ -12,26 +14,33 @@ pub use stun_rs::{ }; /// Errors that can occur when handling a STUN packet. -#[derive(Debug, thiserror::Error)] -pub enum Error { +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum StunError { /// The STUN message could not be parsed or is otherwise invalid. - #[error("invalid message")] - InvalidMessage, + #[snafu(display("invalid message"))] + InvalidMessage {}, /// STUN request is not a binding request when it should be. - #[error("not binding")] - NotBinding, + #[snafu(display("not binding"))] + NotBinding {}, /// STUN packet is not a response when it should be. - #[error("not success response")] - NotSuccessResponse, + #[snafu(display("not success response"))] + NotSuccessResponse {}, /// STUN response has malformed attributes. 
- #[error("malformed attributes")] - MalformedAttrs, + #[snafu(display("malformed attributes"))] + MalformedAttrs {}, /// STUN request didn't end in fingerprint. - #[error("no fingerprint")] - NoFingerprint, + #[snafu(display("no fingerprint"))] + NoFingerprint {}, /// STUN request had bogus fingerprint. - #[error("invalid fingerprint")] - InvalidFingerprint, + #[snafu(display("invalid fingerprint"))] + InvalidFingerprint {}, } /// Generates a binding request STUN packet. @@ -75,16 +84,16 @@ pub fn is(b: &[u8]) -> bool { } /// Parses a STUN binding request. -pub fn parse_binding_request(b: &[u8]) -> Result { +pub fn parse_binding_request(b: &[u8]) -> Result { let ctx = DecoderContextBuilder::default() .with_validation() // ensure fingerprint is validated .build(); let decoder = MessageDecoderBuilder::default().with_context(ctx).build(); - let (msg, _) = decoder.decode(b).map_err(|_| Error::InvalidMessage)?; + let (msg, _) = decoder.decode(b).map_err(|_| InvalidMessageSnafu.build())?; let tx = *msg.transaction_id(); if msg.method() != methods::BINDING { - return Err(Error::NotBinding); + return Err(NotBindingSnafu.build()); } // TODO: Tailscale sets the software to tailscale, we should check if we want to do this too. @@ -95,7 +104,7 @@ pub fn parse_binding_request(b: &[u8]) -> Result { .map(|attr| !attr.is_fingerprint()) .unwrap_or_default() { - return Err(Error::NoFingerprint); + return Err(NoFingerprintSnafu.build()); } Ok(tx) @@ -103,13 +112,13 @@ pub fn parse_binding_request(b: &[u8]) -> Result { /// Parses a successful binding response STUN packet. /// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. 
-pub fn parse_response(b: &[u8]) -> Result<(TransactionId, SocketAddr), Error> { +pub fn parse_response(b: &[u8]) -> Result<(TransactionId, SocketAddr), StunError> { let decoder = MessageDecoder::default(); - let (msg, _) = decoder.decode(b).map_err(|_| Error::InvalidMessage)?; + let (msg, _) = decoder.decode(b).map_err(|_| InvalidMessageSnafu.build())?; let tx = *msg.transaction_id(); if msg.class() != MessageClass::SuccessResponse { - return Err(Error::NotSuccessResponse); + return Err(NotSuccessResponseSnafu.build()); } // Read through the attributes. @@ -144,7 +153,7 @@ pub fn parse_response(b: &[u8]) -> Result<(TransactionId, SocketAddr), Error> { return Ok((tx, addr)); } - Err(Error::MalformedAttrs) + Err(MalformedAttrsSnafu.build()) } #[cfg(test)] diff --git a/iroh-relay/src/quic.rs b/iroh-relay/src/quic.rs index 678d1e5910e..e5344faea8a 100644 --- a/iroh-relay/src/quic.rs +++ b/iroh-relay/src/quic.rs @@ -2,9 +2,14 @@ //! for QUIC address discovery. use std::{net::SocketAddr, sync::Arc}; -use anyhow::Result; use n0_future::time::Duration; -use quinn::{crypto::rustls::QuicClientConfig, VarInt}; +use nested_enum_utils::common_fields; +use quinn::{ + crypto::rustls::{NoInitialCipherSuite, QuicClientConfig}, + VarInt, +}; +use snafu::{Backtrace, Snafu}; +use tokio::sync::watch; /// ALPN for our quic addr discovery pub const ALPN_QUIC_ADDR_DISC: &[u8] = b"/iroh-qad/0"; @@ -15,7 +20,8 @@ pub const QUIC_ADDR_DISC_CLOSE_REASON: &[u8] = b"finished"; #[cfg(feature = "server")] pub(crate) mod server { - use quinn::{crypto::rustls::QuicServerConfig, ApplicationClose}; + use quinn::{crypto::rustls::QuicServerConfig, ApplicationClose, ConnectionError}; + use snafu::ResultExt; use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{debug, info, info_span, Instrument}; @@ -29,6 +35,34 @@ pub(crate) mod server { handle: AbortOnDropHandle<()>, } + /// Server spawn errors + #[allow(missing_docs)] + #[derive(Debug, Snafu)] + 
 #[non_exhaustive] + pub enum QuicSpawnError { + #[snafu(transparent)] + NoInitialCipherSuite { + source: NoInitialCipherSuite, + backtrace: Option<Backtrace>, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Unable to spawn a QUIC endpoint server"))] + EndpointServer { + source: std::io::Error, + backtrace: Option<Backtrace>, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Unable to get the local address from the endpoint"))] + LocalAddr { + source: std::io::Error, + backtrace: Option<Backtrace>, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + } + impl QuicServer { /// Returns a handle for this server. /// @@ -60,13 +94,13 @@ pub(crate) mod server { /// # Errors /// If the given `quic_config` contains a [`rustls::ServerConfig`] that cannot /// be converted to a [`QuicServerConfig`], usually because it does not support - /// TLS 1.3, this method will error. + /// TLS 1.3, a [`NoInitialCipherSuite`] will occur. /// /// # Panics /// If there is a panic during a connection, it will be propagated /// up here. Any other errors in a connection will be logged as a /// warning.
- pub(crate) fn spawn(mut quic_config: QuicConfig) -> Result { + pub(crate) fn spawn(mut quic_config: QuicConfig) -> Result { quic_config.server_config.alpn_protocols = vec![crate::quic::ALPN_QUIC_ADDR_DISC.to_vec()]; let server_config = QuicServerConfig::try_from(quic_config.server_config)?; @@ -79,8 +113,9 @@ pub(crate) mod server { // enable sending quic address discovery frames .send_observed_address_reports(true); - let endpoint = quinn::Endpoint::server(server_config, quic_config.bind_addr)?; - let bind_addr = endpoint.local_addr()?; + let endpoint = quinn::Endpoint::server(server_config, quic_config.bind_addr) + .context(EndpointServerSnafu)?; + let bind_addr = endpoint.local_addr().context(LocalAddrSnafu)?; info!(?bind_addr, "QUIC server listening on"); @@ -145,12 +180,13 @@ pub(crate) mod server { /// Closes the underlying QUIC endpoint and the tasks running the /// QUIC connections. - pub async fn shutdown(mut self) -> Result<()> { + pub async fn shutdown(mut self) { self.cancel.cancel(); if !self.task_handle().is_finished() { - self.task_handle().await? + // only possible error is a `JoinError`, no errors about what might + // have happened during a connection are propagated. + _ = self.task_handle().await; } - Ok(()) } } @@ -170,11 +206,11 @@ pub(crate) mod server { } /// Handle the connection from the client. - async fn handle_connection(incoming: quinn::Incoming) -> Result<()> { + async fn handle_connection(incoming: quinn::Incoming) -> Result<(), ConnectionError> { let connection = match incoming.await { Ok(conn) => conn, Err(e) => { - return Err(e.into()); + return Err(e); } }; debug!("established"); @@ -186,11 +222,31 @@ pub(crate) mod server { { Ok(()) } - _ => Err(connection_err.into()), + _ => Err(connection_err), } } } +/// Quic client related errors. 
+#[common_fields({ + backtrace: Option<Backtrace>, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum Error { + #[snafu(transparent)] + Connect { source: quinn::ConnectError }, + #[snafu(transparent)] + Connection { source: quinn::ConnectionError }, + #[snafu(transparent)] + WatchRecv { source: watch::error::RecvError }, + #[snafu(transparent)] + NoIntitialCipherSuite { source: NoInitialCipherSuite }, +} + /// Handles the client side of QUIC address discovery. #[derive(Debug)] pub struct QuicClient { @@ -203,7 +259,10 @@ pub struct QuicClient { impl QuicClient { /// Create a new QuicClient to handle the client side of QUIC /// address discovery. - pub fn new(ep: quinn::Endpoint, mut client_config: rustls::ClientConfig) -> Result<Self> { + pub fn new( + ep: quinn::Endpoint, + mut client_config: rustls::ClientConfig, + ) -> Result<Self, Error> { // add QAD alpn client_config.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.into()]; // go from rustls client config to rustls QUIC specific client config to @@ -240,7 +299,7 @@ impl QuicClient { &self, server_addr: SocketAddr, host: &str, - ) -> Result<(SocketAddr, std::time::Duration)> { + ) -> Result<(SocketAddr, std::time::Duration), Error> { let connecting = self .ep .connect_with(self.client_config.clone(), server_addr, host); @@ -251,7 +310,7 @@ impl QuicClient { // tokio::select!
{ // _ = cancel.cancelled() => { // conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON); - // anyhow::bail!("QUIC address discovery canceled early"); + // bail!("QUIC address discovery canceled early"); // }, // res = external_addresses.wait_for(|addr| addr.is_some()) => { // let addr = res?.expect("checked"); @@ -284,8 +343,8 @@ impl QuicClient { mod tests { use std::net::Ipv4Addr; - use anyhow::Context; use n0_future::{task::AbortOnDropHandle, time}; + use n0_snafu::{Error, Result, ResultExt}; use quinn::crypto::rustls::QuicServerConfig; use tokio::time::Instant; use tracing::{debug, info, info_span, Instrument}; @@ -299,8 +358,8 @@ mod tests { #[tokio::test] #[traced_test] - async fn quic_endpoint_basic() -> anyhow::Result<()> { - let host: Ipv4Addr = "127.0.0.1".parse()?; + async fn quic_endpoint_basic() -> Result { + let host: Ipv4Addr = "127.0.0.1".parse().unwrap(); // create a server config with self signed certificates let (_, server_config) = super::super::server::testing::self_signed_tls_certs_and_config(); let bind_addr = SocketAddr::new(host.into(), 0); @@ -310,8 +369,9 @@ mod tests { })?; // create a client-side endpoint - let client_endpoint = quinn::Endpoint::client(SocketAddr::new(host.into(), 0))?; - let client_addr = client_endpoint.local_addr()?; + let client_endpoint = + quinn::Endpoint::client(SocketAddr::new(host.into(), 0)).context("client")?; + let client_addr = client_endpoint.local_addr().context("local addr")?; // create the client configuration used for the client endpoint when they // initiate a connection with the server @@ -325,7 +385,7 @@ mod tests { // wait until the endpoint delivers the closing message to the server client_endpoint.wait_idle().await; // shut down the quic server - quic_server.shutdown().await?; + quic_server.shutdown().await; assert_eq!(client_addr, addr); Ok(()) @@ -333,15 +393,18 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_qad_client_closes_unresponsive_fast() -> 
anyhow::Result<()> { + async fn test_qad_client_closes_unresponsive_fast() -> Result { // create a client-side endpoint let client_endpoint = - quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))?; + quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)) + .context("client")?; // create an socket that does not respond. let server_socket = - tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)).await?; - let server_addr = server_socket.local_addr()?; + tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)) + .await + .context("bind")?; + let server_addr = server_socket.local_addr().context("local addr")?; // create the client configuration used for the client endpoint when they // initiate a connection with the server @@ -369,7 +432,7 @@ mod tests { println!("Closed in {time:?}"); assert!(Duration::from_millis(900) < time); - assert!(time < Duration::from_millis(1100)); + assert!(time < Duration::from_millis(1500)); // give it some lee-way. Apparently github actions ubuntu runners can be slow? Ok(()) } @@ -382,24 +445,28 @@ mod tests { /// all packets on the server-side for 2 seconds. #[tokio::test] #[traced_test] - async fn test_qad_connect_delayed() -> anyhow::Result<()> { + async fn test_qad_connect_delayed() -> Result { // Create a socket for our QAD server. We need the socket separately because we // need to pop off messages before we attach it to the Quinn Endpoint. - let socket = - tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)).await?; - let server_addr = socket.local_addr()?; + let socket = tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)) + .await + .context("bind")?; + let server_addr = socket.local_addr().context("local addr")?; info!(addr = ?server_addr, "server socket bound"); // Create a QAD server with a self-signed cert, all manually. 
- let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])?; + let cert = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).context("self signed")?; let key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); let mut server_crypto = rustls::ServerConfig::builder() .with_no_client_auth() - .with_single_cert(vec![cert.cert.into()], key.into())?; + .with_single_cert(vec![cert.cert.into()], key.into()) + .context("tls")?; server_crypto.key_log = Arc::new(rustls::KeyLogFile::new()); server_crypto.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.to_vec()]; - let mut server_config = - quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(server_crypto)?)); + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(server_crypto).context("config")?, + )); let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); transport_config.send_observed_address_reports(true); @@ -420,16 +487,17 @@ mod tests { let server = quinn::Endpoint::new( Default::default(), Some(server_config), - socket.into_std()?, + socket.into_std().unwrap(), Arc::new(quinn::TokioRuntime), - )?; + ) + .context("endpoint new")?; info!("accepting conn"); - let incoming = server.accept().await.context("no conn")?; + let incoming = server.accept().await.expect("missing conn"); info!("incoming!"); - let conn = incoming.await?; + let conn = incoming.await.context("incoming")?; conn.closed().await; server.wait_idle().await; - Ok::<(), anyhow::Error>(()) + Ok::<_, Error>(()) } .instrument(info_span!("server")), ); @@ -437,7 +505,8 @@ mod tests { info!("starting client"); let client_endpoint = - quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))?; + quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)) + .context("client")?; // create the client configuration used for the client endpoint when they // initiate a connection with the server @@ -450,12 +519,16 @@ mod tests { 
Duration::from_secs(10), quic_client.get_addr_and_latency(server_addr, "localhost"), ) - .await??; + .await + .context("timeout")??; let duration = start.elapsed(); info!(?duration, ?addr, ?latency, "QAD succeeded"); assert!(duration >= Duration::from_secs(1)); - time::timeout(Duration::from_secs(10), server_task).await???; + time::timeout(Duration::from_secs(10), server_task) + .await + .context("timeout")? + .context("server task")??; Ok(()) } diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index ca8d4935cbc..73bd29d712e 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -18,19 +18,21 @@ use std::{fmt, future::Future, net::SocketAddr, num::NonZeroU32, pin::Pin, sync::Arc}; -use anyhow::{anyhow, bail, Context, Result}; use derive_more::Debug; use http::{ - response::Builder as ResponseBuilder, HeaderMap, Method, Request, Response, StatusCode, + header::InvalidHeaderValue, response::Builder as ResponseBuilder, HeaderMap, Method, Request, + Response, StatusCode, }; use hyper::body::Incoming; use iroh_base::NodeId; #[cfg(feature = "test-utils")] use iroh_base::RelayUrl; use n0_future::{future::Boxed, StreamExt}; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, ResultExt, Snafu}; use tokio::{ net::{TcpListener, UdpSocket}, - task::JoinSet, + task::{JoinError, JoinSet}, }; use tokio_util::task::AbortOnDropHandle; use tracing::{debug, error, info, info_span, instrument, trace, warn, Instrument}; @@ -39,7 +41,7 @@ use crate::{ defaults::DEFAULT_KEY_CACHE_CAPACITY, http::RELAY_PROBE_PATH, protos, - quic::server::{QuicServer, ServerHandle as QuicServerHandle}, + quic::server::{QuicServer, QuicSpawnError, ServerHandle as QuicServerHandle}, }; mod client; @@ -262,7 +264,7 @@ pub struct Server { /// Handle to the quic server. quic_handle: Option, /// The main task running the server. - supervisor: AbortOnDropHandle>, + supervisor: AbortOnDropHandle>, /// The certificate for the server. 
/// /// If the server has manual certificates configured the certificate chain will be @@ -271,9 +273,57 @@ pub struct Server { metrics: RelayMetrics, } +/// Server spawn errors +#[common_fields({ + backtrace: Option<Backtrace>, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum SpawnError { + #[snafu(display("Unable to get local address"))] + LocalAddr { source: std::io::Error }, + #[snafu(display("Failed to bind STUN listener"))] + UdpSocketBind { source: std::io::Error }, + #[snafu(display("Failed to spawn QUIC server"))] + QuicSpawn { source: QuicSpawnError }, + #[snafu(display("Failed to parse TLS header"))] + TlsHeaderParse { source: InvalidHeaderValue }, + #[snafu(display("Failed to bind TcpListener"))] + BindTlsListener { source: std::io::Error }, + #[snafu(display("No local address"))] + NoLocalAddr { source: std::io::Error }, + #[snafu(display("Failed to bind server socket to {addr}"))] + BindTcpListener { addr: SocketAddr }, +} + +/// Server task errors +#[common_fields({ + backtrace: Option<Backtrace>, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum SupervisorError { + #[snafu(display("Error starting metrics server"))] + Metrics { source: std::io::Error }, + #[snafu(display("Acme event stream finished"))] + AcmeEventStreamFinished {}, + #[snafu(transparent)] + JoinError { source: JoinError }, + #[snafu(display("No relay services are enabled"))] + NoRelayServicesEnabled {}, + #[snafu(display("Task cancelled"))] + TaskCancelled {}, +} + impl Server { /// Starts the server.
- pub async fn spawn(config: ServerConfig) -> Result + pub async fn spawn(config: ServerConfig) -> Result where EC: fmt::Debug + 'static, EA: fmt::Debug + 'static, @@ -289,8 +339,9 @@ impl Server { registry.register_all(&metrics); tasks.spawn( async move { - iroh_metrics::service::start_metrics_server(addr, Arc::new(registry)).await?; - anyhow::Ok(()) + iroh_metrics::service::start_metrics_server(addr, Arc::new(registry)) + .await + .context(MetricsSnafu) } .instrument(info_span!("metrics-server")), ); @@ -302,15 +353,19 @@ impl Server { debug!("Starting STUN server"); match UdpSocket::bind(stun.bind_addr).await { Ok(sock) => { - let addr = sock.local_addr()?; + let addr = sock.local_addr().context(LocalAddrSnafu)?; info!("STUN server listening on {addr}"); + let stun_metrics = metrics.stun.clone(); tasks.spawn( - server_stun_listener(sock, metrics.stun.clone()) - .instrument(info_span!("stun-server", %addr)), + async move { + server_stun_listener(sock, stun_metrics).await; + Ok(()) + } + .instrument(info_span!("stun-server", %addr)), ); Some(addr) } - Err(err) => bail!("failed to bind STUN listener: {err:#?}"), + Err(err) => return Err(err).context(UdpSocketBindSnafu), } } None => None, @@ -328,7 +383,7 @@ impl Server { let quic_server = match config.quic { Some(quic_config) => { debug!("Starting QUIC server {}", quic_config.bind_addr); - Some(QuicServer::spawn(quic_config)?) + Some(QuicServer::spawn(quic_config).context(QuicSpawnSnafu)?) 
} None => None, }; @@ -340,7 +395,7 @@ impl Server { debug!("Starting Relay server"); let mut headers = HeaderMap::new(); for (name, value) in TLS_HEADERS.iter() { - headers.insert(*name, value.parse()?); + headers.insert(*name, value.parse().context(TlsHeaderParseSnafu)?); } let relay_bind_addr = match relay_config.tls { Some(ref tls) => tls.https_bind_addr, @@ -375,7 +430,7 @@ impl Server { Err(err) => error!("error: {err:?}"), } } - Err(anyhow!("acme event stream finished")) + Err(AcmeEventStreamFinishedSnafu.build()) } .instrument(info_span!("acme")), ); @@ -401,11 +456,14 @@ impl Server { // these standalone. let http_listener = TcpListener::bind(&relay_config.http_bind_addr) .await - .context("failed to bind http")?; - let http_addr = http_listener.local_addr()?; + .context(BindTlsListenerSnafu)?; + let http_addr = http_listener.local_addr().context(NoLocalAddrSnafu)?; tasks.spawn( - run_captive_portal_service(http_listener) - .instrument(info_span!("http-service", addr = %http_addr)), + async move { + run_captive_portal_service(http_listener).await; + Ok(()) + } + .instrument(info_span!("http-service", addr = %http_addr)), ); Some(http_addr) } @@ -447,7 +505,7 @@ impl Server { /// Requests graceful shutdown. /// /// Returns once all server tasks have stopped. - pub async fn shutdown(self) -> Result<()> { + pub async fn shutdown(self) -> Result<(), SupervisorError> { // Only the Relay server and QUIC server need shutting down, the supervisor will abort the tasks in // the JoinSet when the server terminates. if let Some(handle) = self.relay_handle { @@ -463,7 +521,7 @@ impl Server { /// /// This allows waiting for the server's supervisor task to finish. Can be useful in /// case there is an error in the server before it is shut down. 
- pub fn task_handle(&mut self) -> &mut AbortOnDropHandle> { + pub fn task_handle(&mut self) -> &mut AbortOnDropHandle> { &mut self.supervisor } @@ -528,10 +586,10 @@ impl Server { /// The supervisor finishes once all tasks are finished. #[instrument(skip_all)] async fn relay_supervisor( - mut tasks: JoinSet>, + mut tasks: JoinSet>, mut relay_http_server: Option, mut quic_server: Option, -) -> Result<()> { +) -> Result<(), SupervisorError> { let quic_enabled = quic_server.is_some(); let mut quic_fut = match quic_server { Some(ref mut server) => n0_future::Either::Left(server.task_handle()), @@ -545,9 +603,9 @@ async fn relay_supervisor( let res = tokio::select! { biased; Some(ret) = tasks.join_next() => ret, - ret = &mut quic_fut, if quic_enabled => ret.map(anyhow::Ok), - ret = &mut relay_fut, if relay_enabled => ret.map(anyhow::Ok), - else => Ok(Err(anyhow!("No relay services are enabled."))), + ret = &mut quic_fut, if quic_enabled => ret.map(Ok), + ret = &mut relay_fut, if relay_enabled => ret.map(Ok), + else => Ok(Err(NoRelayServicesEnabledSnafu.build())), }; let ret = match res { Ok(Ok(())) => { @@ -556,7 +614,7 @@ async fn relay_supervisor( } Ok(Err(err)) => { error!(%err, "Task failed"); - Err(err.context("task failed")) + Err(err) } Err(err) => { if let Ok(panic) = err.try_into_panic() { @@ -564,7 +622,7 @@ async fn relay_supervisor( std::panic::resume_unwind(panic); } debug!("Task cancelled"); - Err(anyhow!("task cancelled")) + Err(TaskCancelledSnafu.build()) } }; @@ -576,7 +634,7 @@ async fn relay_supervisor( // Ensure the QUIC server is closed if let Some(server) = quic_server { - server.shutdown().await?; + server.shutdown().await; } // Stop all remaining tasks @@ -588,7 +646,7 @@ async fn relay_supervisor( /// Runs a STUN server. /// /// When the future is dropped, the server stops. 
-async fn server_stun_listener(sock: UdpSocket, metrics: Arc) -> Result<()> { +async fn server_stun_listener(sock: UdpSocket, metrics: Arc) { info!(addr = ?sock.local_addr().ok(), "running STUN server"); let sock = Arc::new(sock); let mut buffer = vec![0u8; 64 << 10]; @@ -740,7 +798,7 @@ fn is_challenge_char(c: char) -> bool { } /// This is a future that never returns, drop it to cancel/abort. -async fn run_captive_portal_service(http_listener: TcpListener) -> Result<()> { +async fn run_captive_portal_service(http_listener: TcpListener) { info!("serving"); // If this future is cancelled, this is dropped and all tasks are aborted. @@ -819,20 +877,26 @@ mod tests { use std::{net::Ipv4Addr, time::Duration}; use bytes::Bytes; - use http::header::UPGRADE; + use http::{header::UPGRADE, StatusCode}; use iroh_base::{NodeId, RelayUrl, SecretKey}; - use n0_future::{FutureExt, SinkExt}; - use testresult::TestResult; + use n0_future::{FutureExt, SinkExt, StreamExt}; + use n0_snafu::{Result, ResultExt}; + use tokio::net::UdpSocket; + use tracing::{info, instrument}; use tracing_test::traced_test; - use super::*; + use super::{ + Access, AccessConfig, RelayConfig, Server, ServerConfig, SpawnError, StunConfig, + NO_CONTENT_CHALLENGE_HEADER, NO_CONTENT_RESPONSE_HEADER, + }; use crate::{ client::{conn::ReceivedMessage, ClientBuilder, SendMessage}, dns::DnsResolver, http::{Protocol, HTTP_UPGRADE_PROTOCOL}, + protos, }; - async fn spawn_local_relay() -> Result { + async fn spawn_local_relay() -> std::result::Result { Server::spawn(ServerConfig::<(), ()> { relay: Some(RelayConfig::<(), ()> { http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), @@ -848,6 +912,7 @@ mod tests { .await } + #[instrument] async fn try_send_recv( client_a: &mut crate::client::Client, client_b: &mut crate::client::Client, @@ -863,7 +928,8 @@ mod tests { else { continue; }; - return res.context("stream finished")?; + let res = res.expect("stream finished")?; + return Ok(res); } panic!("failed to send and recv 
message"); } @@ -962,7 +1028,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_relay_clients_both_relay() -> TestResult<()> { + async fn test_relay_clients_both_relay() -> Result<()> { let server = spawn_local_relay().await.unwrap(); let relay_url = format!("http://{}", server.http_addr().unwrap()); let relay_url: RelayUrl = relay_url.parse().unwrap(); @@ -1014,7 +1080,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_relay_clients_both_websockets() -> TestResult<()> { + async fn test_relay_clients_both_websockets() -> Result<()> { let server = spawn_local_relay().await?; let relay_url = format!("http://{}", server.http_addr().unwrap()); @@ -1076,7 +1142,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_relay_clients_websocket_and_relay() -> TestResult<()> { + async fn test_relay_clients_websocket_and_relay() -> Result<()> { let server = spawn_local_relay().await.unwrap(); let relay_url = format!("http://{}", server.http_addr().unwrap()); @@ -1165,6 +1231,9 @@ mod tests { #[tokio::test] #[traced_test] async fn test_relay_access_control() -> Result<()> { + let current_span = tracing::info_span!("this is a test"); + let _guard = current_span.enter(); + let a_secret_key = SecretKey::generate(rand::thread_rng()); let a_key = a_secret_key.public(); @@ -1191,8 +1260,8 @@ mod tests { stun: None, metrics_addr: None, }) - .await - .unwrap(); + .await?; + let relay_url = format!("http://{}", server.http_addr().unwrap()); let relay_url: RelayUrl = relay_url.parse()?; @@ -1213,7 +1282,8 @@ mod tests { } } }) - .await?; + .await + .context("timeout")?; // test that another client has access @@ -1255,7 +1325,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_relay_clients_full() -> TestResult<()> { + async fn test_relay_clients_full() -> Result<()> { let server = spawn_local_relay().await.unwrap(); let relay_url = format!("http://{}", server.http_addr().unwrap()); let relay_url: RelayUrl = relay_url.parse().unwrap(); diff --git 
a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index d04eb58e1e3..6692b7aa10c 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -5,11 +5,12 @@ use std::{ time::Duration, }; -use anyhow::{bail, Context, Result}; use bytes::Bytes; use iroh_base::NodeId; use n0_future::{FutureExt, Sink, SinkExt, Stream, StreamExt}; +use nested_enum_utils::common_fields; use rand::Rng; +use snafu::{Backtrace, GenerateImplicitData, Snafu}; use time::{Date, OffsetDateTime}; use tokio::{ sync::mpsc::{self, error::TrySendError}, @@ -21,9 +22,14 @@ use tracing::{debug, error, instrument, trace, warn, Instrument}; use crate::{ protos::{ disco, - relay::{write_frame, Frame, PING_INTERVAL}, + relay::{write_frame, Frame, SendError as SendRelayError, PING_INTERVAL}, + }, + server::{ + clients::Clients, + metrics::Metrics, + streams::{RelayedStream, StreamError}, + ClientRateLimit, }, - server::{clients::Clients, metrics::Metrics, streams::RelayedStream, ClientRateLimit}, PingTracker, }; @@ -180,6 +186,96 @@ impl Client { } } +/// Handle frame error +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum HandleFrameError { + #[snafu(transparent)] + ForwardPacket { source: ForwardPacketError }, + #[snafu(transparent)] + Streams { source: StreamError }, + #[snafu(display("Stream terminated"))] + StreamTerminated { + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(transparent)] + Relay { source: SendRelayError }, + #[snafu(display("Server issue: {problem:?}"))] + Health { + problem: Bytes, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, +} + +/// Run error +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum RunError { + #[snafu(transparent)] + ForwardPacket { source: ForwardPacketError }, + #[snafu(display("Flush"))] + Flush { + #[snafu(implicit)] + span_trace: 
n0_snafu::SpanTrace, + }, + #[snafu(transparent)] + HandleFrame { source: HandleFrameError }, + #[snafu(display("Server.disco_send_queue dropped"))] + DiscoSendQueuePacketDrop { + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Failed to send disco packet"))] + DiscoPacketSend { + source: SendRelayError, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Server.send_queue dropped"))] + SendQueuePacketDrop { + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Failed to send packet"))] + PacketSend { + source: SendRelayError, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Server.node_gone dropped"))] + NodeGoneDrop { + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("NodeGone write frame failed"))] + NodeGoneWriteFrame { + source: SendRelayError, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Keep alive write frame failed"))] + KeepAliveWriteFrame { + source: SendRelayError, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Tick flush"))] + TickFlush { + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, +} + /// Manages all the reads and writes to this client. It periodically sends a `KEEP_ALIVE` /// message to the client to keep the connection alive. 
/// @@ -243,7 +339,9 @@ impl Actor { self.metrics.disconnects.inc(); } - async fn run_inner(&mut self, done: CancellationToken) -> Result<()> { + async fn run_inner(&mut self, done: CancellationToken) -> Result<(), RunError> { + use snafu::ResultExt; + // Add some jitter to ping pong interactions, to avoid all pings being sent at the same time let next_interval = || { let random_secs = rand::rngs::OsRng.gen_range(1..=5); @@ -262,29 +360,29 @@ impl Actor { _ = done.cancelled() => { trace!("actor loop cancelled, exiting"); // final flush - self.stream.flush().await.context("flush")?; + self.stream.flush().await.map_err(|_| FlushSnafu.build())?; break; } maybe_frame = self.stream.next() => { - self.handle_frame(maybe_frame).await.context("handle read")?; + self.handle_frame(maybe_frame).await?; // reset the ping interval, we just received a message ping_interval.reset(); } // First priority, disco packets packet = self.disco_send_queue.recv() => { - let packet = packet.context("Server.disco_send_queue dropped")?; - self.send_disco_packet(packet).await.context("send packet")?; + let packet = packet.ok_or(DiscoSendQueuePacketDropSnafu.build())?; + self.send_disco_packet(packet).await.context(DiscoPacketSendSnafu)?; } // Second priority, sending regular packets packet = self.send_queue.recv() => { - let packet = packet.context("Server.send_queue dropped")?; - self.send_packet(packet).await.context("send packet")?; + let packet = packet.ok_or(SendQueuePacketDropSnafu.build())?; + self.send_packet(packet).await.context(PacketSendSnafu)?; } // Last priority, sending left nodes node_id = self.node_gone.recv() => { - let node_id = node_id.context("Server.node_gone dropped")?; + let node_id = node_id.ok_or(NodeGoneDropSnafu.build())?; trace!("node_id gone: {:?}", node_id); - self.write_frame(Frame::NodeGone { node_id }).await?; + self.write_frame(Frame::NodeGone { node_id }).await.context(NodeGoneWriteFrameSnafu)?; } _ = self.ping_tracker.timeout() => { trace!("pong timed 
out"); @@ -295,11 +393,14 @@ impl Actor { // new interval ping_interval.reset_after(next_interval()); let data = self.ping_tracker.new_ping(); - self.write_frame(Frame::Ping { data }).await?; + self.write_frame(Frame::Ping { data }).await.context(KeepAliveWriteFrameSnafu)?; } } - self.stream.flush().await.context("tick flush")?; + self.stream + .flush() + .await + .map_err(|_| TickFlushSnafu.build())?; } Ok(()) } @@ -307,7 +408,7 @@ impl Actor { /// Writes the given frame to the connection. /// /// Errors if the send does not happen within the `timeout` duration - async fn write_frame(&mut self, frame: Frame) -> Result<()> { + async fn write_frame(&mut self, frame: Frame) -> Result<(), SendRelayError> { write_frame(&mut self.stream, frame, Some(self.timeout)).await } @@ -315,7 +416,7 @@ impl Actor { /// /// Errors if the send does not happen within the `timeout` duration /// Does not flush. - async fn send_raw(&mut self, packet: Packet) -> Result<()> { + async fn send_raw(&mut self, packet: Packet) -> Result<(), SendRelayError> { let src_key = packet.src; let content = packet.data; @@ -326,7 +427,7 @@ impl Actor { .await } - async fn send_packet(&mut self, packet: Packet) -> Result<()> { + async fn send_packet(&mut self, packet: Packet) -> Result<(), SendRelayError> { trace!("send packet"); match self.send_raw(packet).await { Ok(()) => { @@ -340,7 +441,7 @@ impl Actor { } } - async fn send_disco_packet(&mut self, packet: Packet) -> Result<()> { + async fn send_disco_packet(&mut self, packet: Packet) -> Result<(), SendRelayError> { trace!("send disco packet"); match self.send_raw(packet).await { Ok(()) => { @@ -355,11 +456,14 @@ impl Actor { } /// Handles frame read results. 
- async fn handle_frame(&mut self, maybe_frame: Option>) -> Result<()> { + async fn handle_frame( + &mut self, + maybe_frame: Option>, + ) -> Result<(), HandleFrameError> { trace!(?maybe_frame, "handle incoming frame"); let frame = match maybe_frame { Some(frame) => frame?, - None => anyhow::bail!("stream terminated"), + None => return Err(StreamTerminatedSnafu.build()), }; match frame { @@ -381,9 +485,7 @@ impl Actor { Frame::Pong { data } => { self.ping_tracker.pong_received(data); } - Frame::Health { problem } => { - bail!("server issue: {:?}", problem); - } + Frame::Health { problem } => return Err(HealthSnafu { problem }.build()), _ => { self.metrics.unknown_frames.inc(); } @@ -417,16 +519,21 @@ pub(crate) enum SendError { Closed, } -#[derive(Debug, thiserror::Error)] -#[error("failed to forward {scope:?} packet: {reason:?}")] +#[derive(Debug, Snafu)] +#[snafu(display("failed to forward {scope:?} packet: {reason:?}"))] pub(crate) struct ForwardPacketError { scope: PacketScope, reason: SendError, + backtrace: Option, } impl ForwardPacketError { pub(crate) fn new(scope: PacketScope, reason: SendError) -> Self { - Self { scope, reason } + Self { + scope, + reason, + backtrace: GenerateImplicitData::generate(), + } } } @@ -453,7 +560,7 @@ enum State { /// Future which will complete when the item can be yielded. delay: Pin + Send + Sync>>, /// Item to yield when the `delay` future completes. 
- item: anyhow::Result, + item: Result, }, Ready, } @@ -497,7 +604,7 @@ impl RateLimitedRelayedStream { } impl Stream for RateLimitedRelayedStream { - type Item = anyhow::Result; + type Item = Result; #[instrument(name = "rate_limited_relayed_stream", skip_all)] fn poll_next( @@ -647,7 +754,7 @@ impl ClientCounter { mod tests { use bytes::Bytes; use iroh_base::SecretKey; - use testresult::TestResult; + use n0_snafu::{Result, ResultExt}; use tokio_util::codec::Framed; use tracing::info; use tracing_test::traced_test; @@ -660,7 +767,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_client_actor_basic() -> Result<()> { + async fn test_client_actor_basic() -> Result { let (send_queue_s, send_queue_r) = mpsc::channel(10); let (disco_send_queue_s, disco_send_queue_r) = mpsc::channel(10); let (peer_gone_s, peer_gone_r) = mpsc::channel(10); @@ -701,7 +808,7 @@ mod tests { src: node_id, data: Bytes::from(&data[..]), }; - send_queue_s.send(packet.clone()).await?; + send_queue_s.send(packet.clone()).await.context("send")?; let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await?; assert_eq!( frame, @@ -713,7 +820,10 @@ mod tests { // send disco packet println!(" send disco packet"); - disco_send_queue_s.send(packet.clone()).await?; + disco_send_queue_s + .send(packet.clone()) + .await + .context("send")?; let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await?; assert_eq!( frame, @@ -725,7 +835,7 @@ mod tests { // send peer_gone println!("send peer gone"); - peer_gone_s.send(node_id).await?; + peer_gone_s.send(node_id).await.context("send")?; let frame = recv_frame(FrameType::PeerGone, &mut io_rw).await?; assert_eq!(frame, Frame::NodeGone { node_id }); @@ -751,7 +861,8 @@ mod tests { dst_key: target, packet: Bytes::from_static(data), }) - .await?; + .await + .context("send")?; // send disco packet println!(" send disco packet"); @@ -764,21 +875,22 @@ mod tests { dst_key: target, packet: disco_data.clone().into(), }) - .await?; + .await + 
.context("send")?; done.cancel(); - handle.await?; + handle.await.context("join")?; Ok(()) } #[tokio::test] #[traced_test] - async fn test_rate_limit() -> TestResult { + async fn test_rate_limit() -> Result { const LIMIT: u32 = 50; const MAX_FRAMES: u32 = 100; // Rate limiter allowing LIMIT bytes/s - let quota = governor::Quota::per_second(NonZeroU32::try_from(LIMIT)?); + let quota = governor::Quota::per_second(NonZeroU32::try_from(LIMIT).unwrap()); let limiter = governor::RateLimiter::direct(quota); // Build the rate limited stream. @@ -802,8 +914,8 @@ mod tests { // Send a frame, it should arrive. info!("-- send packet"); - frame_writer.send(frame.clone()).await?; - frame_writer.flush().await?; + frame_writer.send(frame.clone()).await.context("send")?; + frame_writer.flush().await.context("flush")?; let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next()) .await .expect("timeout") @@ -813,8 +925,8 @@ mod tests { // Next frame does not arrive. info!("-- send packet"); - frame_writer.send(frame.clone()).await?; - frame_writer.flush().await?; + frame_writer.send(frame.clone()).await.context("send")?; + frame_writer.flush().await.context("flush")?; let res = tokio::time::timeout(Duration::from_millis(100), stream.next()).await; assert!(res.is_err(), "expecting a timeout"); info!("-- timeout happened"); diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 19b6cb8c842..1122d560527 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -9,7 +9,6 @@ use std::{ }, }; -use anyhow::Result; use bytes::Bytes; use dashmap::DashMap; use iroh_base::NodeId; @@ -195,6 +194,7 @@ mod tests { use bytes::Bytes; use iroh_base::SecretKey; + use n0_snafu::{Result, ResultExt}; use tokio::io::DuplexStream; use tokio_util::codec::{Framed, FramedRead}; @@ -222,7 +222,7 @@ mod tests { } #[tokio::test] - async fn test_clients() -> Result<()> { + async fn test_clients() -> Result { let a_key = 
SecretKey::generate(rand::thread_rng()).public(); let b_key = SecretKey::generate(rand::thread_rng()).public(); @@ -271,7 +271,8 @@ mod tests { tokio::time::sleep(Duration::from_millis(100)).await; } }) - .await?; + .await + .context("timeout")?; clients.shutdown().await; Ok(()) diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index ca4ea755edb..b42e63036db 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -2,7 +2,6 @@ use std::{ collections::HashMap, future::Future, net::SocketAddr, pin::Pin, sync::Arc, time::Duration, }; -use anyhow::{bail, ensure, Context as _, Result}; use bytes::Bytes; use derive_more::Debug; use http::{header::CONNECTION, response::Builder as ResponseBuilder}; @@ -13,13 +12,16 @@ use hyper::{ upgrade::Upgraded, HeaderMap, Method, Request, Response, StatusCode, }; -use n0_future::{FutureExt, SinkExt}; +use iroh_base::PublicKey; +use n0_future::{time::Elapsed, FutureExt, SinkExt}; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, ResultExt, Snafu}; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls_acme::AcmeAcceptor; use tokio_util::{codec::Framed, sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{debug, debug_span, error, info, info_span, trace, warn, Instrument}; -use super::{clients::Clients, AccessConfig}; +use super::{clients::Clients, streams::StreamError, AccessConfig, SpawnError}; use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, http::{Protocol, LEGACY_RELAY_PATH, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, @@ -30,7 +32,7 @@ use crate::{ client::Config, metrics::Metrics, streams::{MaybeTlsStream, RelayedStream}, - ClientRateLimit, + BindTcpListenerSnafu, ClientRateLimit, NoLocalAddrSnafu, }, KeyCache, }; @@ -45,7 +47,7 @@ type HyperHandler = Box< + 'static, >; -/// WebSocket GUID needed for accepting websocket connections, see RFC 6455 (https://www.rfc-editor.org/rfc/rfc6455) 
section 1.3 +/// WebSocket GUID needed for accepting websocket connections, see RFC 6455 () section 1.3 const SEC_WEBSOCKET_ACCEPT_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; /// Derives the accept key for WebSocket handshake according to RFC 6455. @@ -69,12 +71,11 @@ fn body_full(content: impl Into) -> BytesBody { http_body_util::Full::new(content.into()) } -fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes)> { +#[allow(clippy::result_large_err)] +fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes), ConnectionHandlerError> { match upgraded.downcast::>() { Ok(parts) => Ok((parts.io.into_inner(), parts.read_buf)), - Err(_) => { - bail!("could not downcast the upgraded connection to MaybeTlsStream") - } + Err(_) => Err(DowncastUpgradeSnafu.build()), } } @@ -147,6 +148,113 @@ pub(super) struct TlsConfig { pub(super) acceptor: TlsAcceptor, } +/// Errors when attempting to upgrade and +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ServeConnectionError { + #[snafu(display("TLS[acme] handshake"))] + Handshake { + source: std::io::Error, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("TLS[acme] serve connection"))] + ServeConnection { + source: hyper::Error, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("TLS[manual] timeout"))] + Timeout { + source: Elapsed, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("TLS[manual] accept"))] + ManualAccept { + source: std::io::Error, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("TLS[acme] accept"))] + LetsEncryptAccept { + source: std::io::Error, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("HTTPS connection"))] + Https { + source: hyper::Error, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("HTTP connection"))] + 
Http { + source: hyper::Error, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, +} + +/// Server accept errors. +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum AcceptError { + #[snafu(display("Unable to receive client information"))] + RecvClientKey { + #[allow(clippy::result_large_err)] + source: StreamError, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Unexpected client version {version}, expected {expected_version}"))] + UnexpectedClientVersion { + version: usize, + expected_version: usize, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(transparent)] + Io { source: std::io::Error }, + #[snafu(display("Client not authenticated: {key:?}"))] + ClientNotAuthenticated { + key: PublicKey, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, +} + +/// Server connection errors, includes errors that can happen on `accept`. +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ConnectionHandlerError { + #[snafu(transparent)] + Accept { source: AcceptError }, + #[snafu(display("Could not downcast the upgraded connection to MaybeTlsStream"))] + DowncastUpgrade { + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(display("Cannot deal with buffered data yet: {buf:?}"))] + BufferNotEmpty { + buf: Bytes, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, +} + /// Builder for the Relay HTTP Server. /// /// Defaults to handling relay requests on the "/relay" (and "/derp" for backwards compatibility) endpoint. @@ -246,7 +354,9 @@ impl ServerBuilder { } /// Builds and spawns an HTTP(S) Relay Server. 
- pub(super) async fn spawn(self) -> Result { + pub(super) async fn spawn(self) -> Result { + use snafu::ResultExt; + let cancel_token = CancellationToken::new(); let service = RelayService::new( @@ -265,9 +375,9 @@ impl ServerBuilder { let listener = TcpListener::bind(&addr) .await - .with_context(|| format!("failed to bind server socket to {addr}"))?; + .map_err(|_| BindTcpListenerSnafu { addr }.build())?; - let addr = listener.local_addr()?; + let addr = listener.local_addr().context(NoLocalAddrSnafu)?; let http_str = tls_config.as_ref().map_or("HTTP/WS", |_| "HTTPS/WSS"); info!("[{http_str}] relay: serving on {addr}"); @@ -500,16 +610,19 @@ impl Inner { /// This handler runs while doing the connection upgrade handshake. Once the connection /// is upgraded it sends the stream to the relay server which takes it over. After /// having sent off the connection this handler returns. - async fn relay_connection_handler(&self, protocol: Protocol, upgraded: Upgraded) -> Result<()> { + async fn relay_connection_handler( + &self, + protocol: Protocol, + upgraded: Upgraded, + ) -> Result<(), ConnectionHandlerError> { debug!(?protocol, "relay_connection upgraded"); let (io, read_buf) = downcast_upgrade(upgraded)?; - ensure!( - read_buf.is_empty(), - "can not deal with buffered data yet: {:?}", - read_buf - ); + if !read_buf.is_empty() { + return Err(BufferNotEmptySnafu { buf: read_buf }.build()); + } - self.accept(protocol, io).await + self.accept(protocol, io).await?; + Ok(()) } /// Adds a new connection to the server and serves it. 
@@ -522,7 +635,9 @@ impl Inner { /// /// [`AsyncRead`]: tokio::io::AsyncRead /// [`AsyncWrite`]: tokio::io::AsyncWrite - async fn accept(&self, protocol: Protocol, io: MaybeTlsStream) -> Result<()> { + async fn accept(&self, protocol: Protocol, io: MaybeTlsStream) -> Result<(), AcceptError> { + use snafu::ResultExt; + trace!(?protocol, "accept: start"); let mut io = match protocol { Protocol::Relay => { @@ -541,9 +656,7 @@ impl Inner { } }; trace!("accept: recv client key"); - let (client_key, info) = recv_client_key(&mut io) - .await - .context("unable to receive client information")?; + let (client_key, info) = recv_client_key(&mut io).await.context(RecvClientKeySnafu)?; trace!("accept: checking access: {:?}", self.access); if !self.access.is_allowed(client_key).await { @@ -553,15 +666,15 @@ impl Inner { .await?; io.flush().await?; - bail!("client is not authenticated: {}", client_key); + return Err(ClientNotAuthenticatedSnafu { key: client_key }.build()); } if info.version != PROTOCOL_VERSION { - bail!( - "unexpected client version {}, expected {}", - info.version, - PROTOCOL_VERSION - ); + return Err(UnexpectedClientVersionSnafu { + version: info.version, + expected_version: PROTOCOL_VERSION, + } + .build()); } trace!("accept: build client conn"); @@ -631,14 +744,27 @@ impl RelayService { } None => { debug!("HTTP: serve connection"); - self.serve_connection(MaybeTlsStream::Plain(stream)).await + self.serve_connection(MaybeTlsStream::Plain(stream)) + .await + .context(HttpSnafu) } }; match res { Ok(()) => {} - Err(error) => match error.downcast_ref::() { - Some(io_error) if io_error.kind() == std::io::ErrorKind::UnexpectedEof => { - debug!(reason=?error, "peer disconnected"); + Err(error) => match error { + ServeConnectionError::ManualAccept { source, .. } + | ServeConnectionError::LetsEncryptAccept { source, .. 
} + if source.kind() == std::io::ErrorKind::UnexpectedEof => + { + debug!(reason=?source, "peer disconnected"); + } + // From hyper: + // `hyper::Error::IncompleteMessage` is hyper's equivalent of UnexpectedEof + ServeConnectionError::Https { source, .. } + | ServeConnectionError::Http { source, .. } + if source.is_incomplete_message() => + { + debug!(reason=?source, "peer disconnected"); } _ => { error!(?error, "failed to handle connection"); @@ -648,49 +774,54 @@ impl RelayService { } /// Serve the tls connection - async fn tls_serve_connection(self, stream: TcpStream, tls_config: TlsConfig) -> Result<()> { + async fn tls_serve_connection( + self, + stream: TcpStream, + tls_config: TlsConfig, + ) -> Result<(), ServeConnectionError> { let TlsConfig { acceptor, config } = tls_config; match acceptor { - TlsAcceptor::LetsEncrypt(a) => match a.accept(stream).await? { - None => { - info!("TLS[acme]: received TLS-ALPN-01 validation request"); - } - Some(start_handshake) => { - debug!("TLS[acme]: start handshake"); - let tls_stream = start_handshake - .into_stream(config) - .await - .context("TLS[acme] handshake")?; - self.serve_connection(MaybeTlsStream::Tls(tls_stream)) - .await - .context("TLS[acme] serve connection")?; + TlsAcceptor::LetsEncrypt(a) => { + match a.accept(stream).await.context(LetsEncryptAcceptSnafu)? { + None => { + info!("TLS[acme]: received TLS-ALPN-01 validation request"); + } + Some(start_handshake) => { + debug!("TLS[acme]: start handshake"); + let tls_stream = start_handshake + .into_stream(config) + .await + .context(HandshakeSnafu)?; + self.serve_connection(MaybeTlsStream::Tls(tls_stream)) + .await + .context(HttpsSnafu)?; + } } - }, + } TlsAcceptor::Manual(a) => { debug!("TLS[manual]: accept"); let tls_stream = tokio::time::timeout(Duration::from_secs(30), a.accept(stream)) .await - .context("TLS[manual] timeout")? - .context("TLS[manual] accept")?; + .context(TimeoutSnafu)? 
+ .context(ManualAcceptSnafu)?; self.serve_connection(MaybeTlsStream::Tls(tls_stream)) .await - .context("TLS[manual] serve connection")?; + .context(ServeConnectionSnafu)?; } } Ok(()) } /// Wrapper for the actual http connection (with upgrades) - async fn serve_connection(self, io: I) -> Result<()> + async fn serve_connection(self, io: I) -> Result<(), hyper::Error> where I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + 'static, { hyper::server::conn::http1::Builder::new() .serve_connection(hyper_util::rt::TokioIo::new(io), self) .with_upgrades() - .await?; - Ok(()) + .await } } @@ -725,11 +856,12 @@ impl std::ops::DerefMut for Handlers { mod tests { use std::sync::Arc; - use anyhow::Result; use bytes::Bytes; use iroh_base::{PublicKey, SecretKey}; use n0_future::{SinkExt, StreamExt}; + use n0_snafu::{Result, ResultExt}; use reqwest::Url; + use snafu::whatever; use tracing::info; use tracing_test::traced_test; @@ -738,7 +870,7 @@ mod tests { client::{ conn::{Conn, ReceivedMessage, SendMessage}, streams::MaybeTlsStreamChained, - Client, ClientBuilder, + Client, ClientBuilder, ConnectError, }, dns::DnsResolver, }; @@ -769,7 +901,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_http_clients_and_server() -> Result<()> { + async fn test_http_clients_and_server() -> Result { let a_key = SecretKey::generate(rand::thread_rng()); let b_key = SecretKey::generate(rand::thread_rng()); @@ -782,13 +914,12 @@ mod tests { // get dial info let port = addr.port(); - let addr = { - if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() { - ipv4_addr - } else { - anyhow::bail!("cannot get ipv4 addr from socket addr {addr:?}"); - } + let addr = if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() { + ipv4_addr + } else { + whatever!("cannot get ipv4 addr from socket addr {addr:?}"); }; + info!("addr: {addr}:{port}"); let relay_addr: Url = format!("http://{addr}:{port}").parse().unwrap(); @@ -800,12 +931,12 @@ mod tests { info!("ping a"); 
client_a.send(SendMessage::Ping([1u8; 8])).await?; - let pong = client_a.next().await.context("eos")??; + let pong = client_a.next().await.expect("eos")?; assert!(matches!(pong, ReceivedMessage::Pong(_))); info!("ping b"); client_b.send(SendMessage::Ping([2u8; 8])).await?; - let pong = client_b.next().await.context("eos")??; + let pong = client_b.next().await.expect("eos")?; assert!(matches!(pong, ReceivedMessage::Pong(_))); info!("sending message from a to b"); @@ -838,7 +969,10 @@ mod tests { Ok(()) } - async fn create_test_client(key: SecretKey, server_url: Url) -> Result<(PublicKey, Client)> { + async fn create_test_client( + key: SecretKey, + server_url: Url, + ) -> Result<(PublicKey, Client), ConnectError> { let public_key = key.public(); let client = ClientBuilder::new(server_url, key, DnsResolver::new()).insecure_skip_cert_verify(true); @@ -847,7 +981,9 @@ mod tests { Ok((public_key, client)) } - fn process_msg(msg: Option>) -> Option<(PublicKey, Bytes)> { + fn process_msg( + msg: Option>, + ) -> Option<(PublicKey, Bytes)> { match msg { Some(Err(e)) => { info!("client `recv` error {e}"); @@ -874,7 +1010,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_https_clients_and_server() -> Result<()> { + async fn test_https_clients_and_server() -> Result { let a_key = SecretKey::generate(rand::thread_rng()); let b_key = SecretKey::generate(rand::thread_rng()); @@ -891,13 +1027,12 @@ mod tests { // get dial info let port = addr.port(); - let addr = { - if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() { - ipv4_addr - } else { - anyhow::bail!("cannot get ipv4 addr from socket addr {addr:?}"); - } + let addr = if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() { + ipv4_addr + } else { + whatever!("cannot get ipv4 addr from socket addr {addr:?}"); }; + info!("Relay listening on: {addr}:{port}"); let url: Url = format!("https://localhost:{port}").parse().unwrap(); @@ -910,12 +1045,12 @@ mod tests { info!("ping a"); client_a.send(SendMessage::Ping([1u8; 
8])).await?; - let pong = client_a.next().await.context("eos")??; + let pong = client_a.next().await.expect("eos")?; assert!(matches!(pong, ReceivedMessage::Pong(_))); info!("ping b"); client_b.send(SendMessage::Ping([2u8; 8])).await?; - let pong = client_b.next().await.context("eos")??; + let pong = client_b.next().await.expect("eos")?; assert!(matches!(pong, ReceivedMessage::Pong(_))); info!("sending message from a to b"); @@ -944,7 +1079,7 @@ mod tests { client_a.close().await?; client_b.close().await?; server.shutdown(); - server.task_handle().await?; + server.task_handle().await.context("join")?; Ok(()) } @@ -957,7 +1092,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_server_basic() -> Result<()> { + async fn test_server_basic() -> Result { info!("Create the server."); let service = RelayService::new( Default::default(), @@ -978,7 +1113,7 @@ mod tests { .await }); let mut client_a = make_test_client(client_a, &key_a).await?; - handler_task.await??; + handler_task.await.context("join")??; info!("Create client B and connect it to the server."); let key_b = SecretKey::generate(rand::thread_rng()); @@ -990,14 +1125,14 @@ mod tests { .await }); let mut client_b = make_test_client(client_b, &key_b).await?; - handler_task.await??; + handler_task.await.context("join")??; info!("Send message from A to B."); let msg = Bytes::from_static(b"hello client b!!"); client_a .send(SendMessage::SendPacket(public_key_b, msg.clone())) .await?; - match client_b.next().await.context("eos")?? { + match client_b.next().await.unwrap()? { ReceivedMessage::ReceivedPacket { remote_node_id, data, @@ -1006,7 +1141,7 @@ mod tests { assert_eq!(&msg[..], data); } msg => { - anyhow::bail!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedPacket msg, got {msg:?}"); } } @@ -1015,7 +1150,7 @@ mod tests { client_b .send(SendMessage::SendPacket(public_key_a, msg.clone())) .await?; - match client_a.next().await.context("eos")?? 
{ + match client_a.next().await.unwrap()? { ReceivedMessage::ReceivedPacket { remote_node_id, data, @@ -1024,7 +1159,7 @@ mod tests { assert_eq!(&msg[..], data); } msg => { - anyhow::bail!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedPacket msg, got {msg:?}"); } } @@ -1045,7 +1180,7 @@ mod tests { } #[tokio::test] - async fn test_server_replace_client() -> Result<()> { + async fn test_server_replace_client() -> Result { info!("Create the server."); let service = RelayService::new( Default::default(), @@ -1066,7 +1201,7 @@ mod tests { .await }); let mut client_a = make_test_client(client_a, &key_a).await?; - handler_task.await??; + handler_task.await.context("join")??; info!("Create client B and connect it to the server."); let key_b = SecretKey::generate(rand::thread_rng()); @@ -1078,14 +1213,14 @@ mod tests { .await }); let mut client_b = make_test_client(client_b, &key_b).await?; - handler_task.await??; + handler_task.await.context("join")??; info!("Send message from A to B."); let msg = Bytes::from_static(b"hello client b!!"); client_a .send(SendMessage::SendPacket(public_key_b, msg.clone())) .await?; - match client_b.next().await.context("eos")?? { + match client_b.next().await.expect("eos")? { ReceivedMessage::ReceivedPacket { remote_node_id, data, @@ -1094,7 +1229,7 @@ mod tests { assert_eq!(&msg[..], data); } msg => { - anyhow::bail!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedPacket msg, got {msg:?}"); } } @@ -1103,7 +1238,7 @@ mod tests { client_b .send(SendMessage::SendPacket(public_key_a, msg.clone())) .await?; - match client_a.next().await.context("eos")?? { + match client_a.next().await.expect("eos")? 
{ ReceivedMessage::ReceivedPacket { remote_node_id, data, @@ -1112,7 +1247,7 @@ mod tests { assert_eq!(&msg[..], data); } msg => { - anyhow::bail!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedPacket msg, got {msg:?}"); } } @@ -1124,7 +1259,7 @@ mod tests { .await }); let mut new_client_b = make_test_client(new_client_b, &key_b).await?; - handler_task.await??; + handler_task.await.context("join")??; // assert!(client_b.recv().await.is_err()); @@ -1133,7 +1268,7 @@ mod tests { client_a .send(SendMessage::SendPacket(public_key_b, msg.clone())) .await?; - match new_client_b.next().await.context("eos")?? { + match new_client_b.next().await.expect("eos")? { ReceivedMessage::ReceivedPacket { remote_node_id, data, @@ -1142,7 +1277,7 @@ mod tests { assert_eq!(&msg[..], data); } msg => { - anyhow::bail!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedPacket msg, got {msg:?}"); } } @@ -1151,7 +1286,7 @@ mod tests { new_client_b .send(SendMessage::SendPacket(public_key_a, msg.clone())) .await?; - match client_a.next().await.context("eos")?? { + match client_a.next().await.expect("eos")? { ReceivedMessage::ReceivedPacket { remote_node_id, data, @@ -1160,7 +1295,7 @@ mod tests { assert_eq!(&msg[..], data); } msg => { - anyhow::bail!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedPacket msg, got {msg:?}"); } } diff --git a/iroh-relay/src/server/resolver.rs b/iroh-relay/src/server/resolver.rs index 415bf830031..1262e4eea70 100644 --- a/iroh-relay/src/server/resolver.rs +++ b/iroh-relay/src/server/resolver.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use anyhow::{anyhow, Result}; use n0_future::{ task::{self, AbortOnDropHandle}, time::{self, Duration}, @@ -31,10 +30,8 @@ where Loader: Send + reloadable_state::core::Loader + 'static, { /// Perform the initial load and construct the [`ReloadingResolver`]. 
- pub async fn init(loader: Loader, interval: Duration) -> Result { - let (reloadable, _) = Reloadable::init_load(loader) - .await - .map_err(|_| anyhow!("Failed to load the certificate"))?; + pub async fn init(loader: Loader, interval: Duration) -> Result { + let (reloadable, _) = Reloadable::init_load(loader).await?; let reloadable = Arc::new(reloadable); let cancel_token = CancellationToken::new(); diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 94fa77735ef..4472958efac 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -5,14 +5,14 @@ use std::{ task::{Context, Poll}, }; -use anyhow::Result; use n0_future::{Sink, Stream}; +use snafu::Snafu; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::Framed; use tokio_websockets::WebSocketStream; use crate::{ - protos::relay::{Frame, RelayCodec}, + protos::relay::{Frame, RecvError, RelayCodec}, KeyCache, }; @@ -68,12 +68,22 @@ impl Sink for RelayedStream { } } +/// Relay stream errors +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum StreamError { + #[snafu(transparent)] + Proto { source: RecvError }, + #[snafu(transparent)] + Ws { source: tokio_websockets::Error }, +} + impl Stream for RelayedStream { - type Item = anyhow::Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { - Self::Relay(ref mut framed) => Pin::new(framed).poll_next(cx), + Self::Relay(ref mut framed) => Pin::new(framed).poll_next(cx).map_err(Into::into), Self::Ws(ref mut ws, ref cache) => match Pin::new(ws).poll_next(cx) { Poll::Ready(Some(Ok(msg))) => { if msg.is_close() { @@ -88,10 +98,9 @@ impl Stream for RelayedStream { ); return Poll::Pending; } - Poll::Ready(Some(Frame::decode_from_ws_msg( - msg.into_payload().into(), - cache, - ))) + let frame = Frame::decode_from_ws_msg(msg.into_payload().into(), cache) + .map_err(Into::into); + Poll::Ready(Some(frame)) } Poll::Ready(Some(Err(e))) => 
Poll::Ready(Some(Err(e.into()))), Poll::Ready(None) => Poll::Ready(None), diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 660d4319500..d5802be1508 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -10,7 +10,7 @@ repository = "https://github.com/n0-computer/iroh" keywords = ["quic", "networking", "holepunching", "p2p"] # Sadly this also needs to be updated in .github/workflows/ci.yml -rust-version = "1.81" +rust-version = "1.85" [lib] # We need "cdylib" to actually generate .wasm files when we run with --target=wasm32-unknown-unknown. @@ -21,8 +21,8 @@ crate-type = ["lib", "cdylib"] workspace = true [dependencies] -aead = { version = "0.5.2", features = ["bytes"] } -anyhow = { version = "1" } +anyhow = "1.0.98" +aead = { version = "0.5.2", features = ["bytes", "std"] } atomic-waker = "1.1.2" concurrent-queue = "2.5" backon = { version = "1.4" } @@ -43,6 +43,8 @@ http = "1" iroh-base = { version = "0.35.0", default-features = false, features = ["key", "relay"], path = "../iroh-base" } iroh-relay = { version = "0.35", path = "../iroh-relay", default-features = false } n0-future = "0.1.2" +n0-snafu = "0.2.0" +nested_enum_utils = "0.2.1" netwatch = { version = "0.5" } pin-project = "1" pkarr = { version = "3.7", default-features = false, features = [ @@ -61,9 +63,9 @@ ring = "0.17" rustls = { version = "0.23", default-features = false, features = ["ring"] } serde = { version = "1.0.219", features = ["derive", "rc"] } smallvec = "1.11.1" +snafu = { version = "0.8.5", features = ["rust_1_81"] } strum = { version = "0.26", features = ["derive"] } stun-rs = "0.1.11" -thiserror = "2" tokio = { version = "1.44.1", features = [ "io-util", "macros", @@ -102,7 +104,7 @@ tracing-subscriber = { version = "0.3", features = [ "env-filter", ], optional = true } indicatif = { version = "0.17", features = ["tokio"], optional = true } -parse-size = { version = "=1.0.0", optional = true } # pinned version to avoid bumping msrv to 1.81 +parse-size = { version = "=1.0.0", optional 
= true, features = ['std'] } # pinned version to avoid bumping msrv to 1.81 # non-wasm-in-browser dependencies [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] @@ -135,7 +137,6 @@ getrandom = { version = "0.3.2", features = ["wasm_js"] } # target-common test/dev dependencies [dev-dependencies] postcard = { version = "1.1.1", features = ["use-std"] } -testresult = "0.4.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } # *non*-wasm-in-browser test/dev dependencies @@ -191,15 +192,19 @@ path = "tests/integration.rs" [[example]] name = "listen" +required-features = ["examples"] [[example]] name = "connect" +required-features = ["examples"] [[example]] name = "listen-unreliable" +required-features = ["examples"] [[example]] name = "connect-unreliable" +required-features = ["examples"] [[example]] name = "dht_discovery" diff --git a/iroh/bench/Cargo.toml b/iroh/bench/Cargo.toml index b274fe91f52..2b6a01d55e8 100644 --- a/iroh/bench/Cargo.toml +++ b/iroh/bench/Cargo.toml @@ -6,12 +6,12 @@ license = "MIT OR Apache-2.0" publish = false [dependencies] -anyhow = "1.0.22" bytes = "1.7" hdrhistogram = { version = "7.2", default-features = false } iroh = { path = ".." 
} iroh-metrics = "0.34" n0-future = "0.1.1" +n0-snafu = "0.2.0" quinn = { package = "iroh-quinn", version = "0.13" } rand = "0.8" rcgen = "0.13" diff --git a/iroh/bench/src/bin/bulk.rs b/iroh/bench/src/bin/bulk.rs index ac7ca7d982b..292ea89b697 100644 --- a/iroh/bench/src/bin/bulk.rs +++ b/iroh/bench/src/bin/bulk.rs @@ -1,11 +1,11 @@ use std::collections::BTreeMap; -use anyhow::Result; use clap::Parser; #[cfg(not(any(target_os = "freebsd", target_os = "openbsd", target_os = "netbsd")))] use iroh_bench::quinn; use iroh_bench::{configure_tracing_subscriber, iroh, rt, s2n, Commands, Opt}; use iroh_metrics::{MetricValue, MetricsGroup}; +use n0_snafu::Result; fn main() { let cmd = Commands::parse(); diff --git a/iroh/bench/src/iroh.rs b/iroh/bench/src/iroh.rs index 35277bffd6a..24de8e518cd 100644 --- a/iroh/bench/src/iroh.rs +++ b/iroh/bench/src/iroh.rs @@ -3,13 +3,13 @@ use std::{ time::{Duration, Instant}, }; -use anyhow::{Context, Result}; use bytes::Bytes; use iroh::{ endpoint::{Connection, ConnectionError, RecvStream, SendStream, TransportConfig}, watcher::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, }; +use n0_snafu::{Result, ResultExt}; use tracing::{trace, warn}; use crate::{ @@ -155,7 +155,7 @@ async fn drain_stream( let mut num_chunks: u64 = 0; if read_unordered { - while let Some(chunk) = stream.read_chunk(usize::MAX, false).await? { + while let Some(chunk) = stream.read_chunk(usize::MAX, false).await.e()? { if first_byte { ttfb = download_start.elapsed(); first_byte = false; @@ -177,7 +177,7 @@ async fn drain_stream( Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), ]; - while let Some(n) = stream.read_chunks(&mut bufs[..]).await? { + while let Some(n) = stream.read_chunks(&mut bufs[..]).await.e()? 
{ if first_byte { ttfb = download_start.elapsed(); first_byte = false; @@ -276,7 +276,7 @@ pub async fn server(endpoint: Endpoint, opt: Opt) -> Result<()> { tokio::spawn(async move { drain_stream(&mut recv_stream, opt.read_unordered).await?; send_data_on_stream(&mut send_stream, opt.download_size).await?; - Ok::<_, anyhow::Error>(()) + Ok::<_, n0_snafu::Error>(()) }); } diff --git a/iroh/bench/src/lib.rs b/iroh/bench/src/lib.rs index a82e0cfa32b..4921793db69 100644 --- a/iroh/bench/src/lib.rs +++ b/iroh/bench/src/lib.rs @@ -5,8 +5,8 @@ use std::{ time::Instant, }; -use anyhow::Result; use clap::Parser; +use n0_snafu::Result; use stats::Stats; use tokio::{ runtime::{Builder, Runtime}, diff --git a/iroh/bench/src/quinn.rs b/iroh/bench/src/quinn.rs index 21028c44c8b..7f34836582c 100644 --- a/iroh/bench/src/quinn.rs +++ b/iroh/bench/src/quinn.rs @@ -4,8 +4,8 @@ use std::{ time::{Duration, Instant}, }; -use anyhow::{Context, Result}; use bytes::Bytes; +use n0_snafu::{Result, ResultExt}; use quinn::{ crypto::rustls::QuicClientConfig, Connection, Endpoint, RecvStream, SendStream, TransportConfig, }; @@ -73,7 +73,7 @@ pub async fn connect_client( quinn::Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap(); let mut roots = RootCertStore::empty(); - roots.add(server_cert)?; + roots.add(server_cert).e()?; let provider = rustls::crypto::ring::default_provider(); @@ -83,7 +83,8 @@ pub async fn connect_client( .with_root_certificates(roots) .with_no_client_auth(); - let mut client_config = quinn::ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto)?)); + let mut client_config = + quinn::ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto).e()?)); client_config.transport_config(Arc::new(transport_config(opt.max_streams, opt.initial_mtu))); let connection = endpoint @@ -124,7 +125,7 @@ async fn drain_stream( let mut num_chunks: u64 = 0; if read_unordered { - while let Some(chunk) = stream.read_chunk(usize::MAX, false).await? 
{ + while let Some(chunk) = stream.read_chunk(usize::MAX, false).await.e()? { if first_byte { ttfb = download_start.elapsed(); first_byte = false; @@ -146,7 +147,7 @@ async fn drain_stream( Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), ]; - while let Some(n) = stream.read_chunks(&mut bufs[..]).await? { + while let Some(n) = stream.read_chunks(&mut bufs[..]).await.e()? { if first_byte { ttfb = download_start.elapsed(); first_byte = false; @@ -245,7 +246,7 @@ pub async fn server(endpoint: Endpoint, opt: Opt) -> Result<()> { tokio::spawn(async move { drain_stream(&mut recv_stream, opt.read_unordered).await?; send_data_on_stream(&mut send_stream, opt.download_size).await?; - Ok::<_, anyhow::Error>(()) + Ok::<_, n0_snafu::Error>(()) }); } diff --git a/iroh/examples/0rtt.rs b/iroh/examples/0rtt.rs index 86efb15959f..945abdaf1b1 100644 --- a/iroh/examples/0rtt.rs +++ b/iroh/examples/0rtt.rs @@ -1,6 +1,5 @@ use std::{env, future::Future, str::FromStr, time::Instant}; -use anyhow::Context; use clap::Parser; use iroh::{ endpoint::{Connecting, Connection}, @@ -9,6 +8,7 @@ use iroh::{ }; use iroh_base::ticket::NodeTicket; use n0_future::{future, StreamExt}; +use n0_snafu::ResultExt; use rand::thread_rng; use tracing::{info, trace}; @@ -28,7 +28,7 @@ struct Args { /// Gets a secret key from the IROH_SECRET environment variable or generates a new random one. /// If the environment variable is set, it must be a valid string representation of a secret key. 
-pub fn get_or_generate_secret_key() -> anyhow::Result { +pub fn get_or_generate_secret_key() -> n0_snafu::Result { if let Ok(secret) = env::var("IROH_SECRET") { // Parse the secret key from string SecretKey::from_str(&secret).context("Invalid secret key format") @@ -50,28 +50,28 @@ async fn pingpong( connection: &Connection, proceed: impl Future, x: u64, -) -> anyhow::Result<()> { - let (mut send, recv) = connection.open_bi().await?; +) -> n0_snafu::Result<()> { + let (mut send, recv) = connection.open_bi().await.e()?; let data = x.to_be_bytes(); - send.write_all(&data).await?; - send.finish()?; + send.write_all(&data).await.e()?; + send.finish().e()?; let mut recv = if proceed.await { // use recv directly if we can proceed recv } else { // proceed returned false, so we have learned that the 0-RTT send was rejected. // at this point we have a fully handshaked connection, so we try again. - let (mut send, recv) = connection.open_bi().await?; - send.write_all(&data).await?; - send.finish()?; + let (mut send, recv) = connection.open_bi().await.e()?; + send.write_all(&data).await.e()?; + send.finish().e()?; recv }; - let echo = recv.read_to_end(8).await?; - anyhow::ensure!(echo == data); + let echo = recv.read_to_end(8).await.e()?; + assert!(echo == data); Ok(()) } -async fn pingpong_0rtt(connecting: Connecting, i: u64) -> anyhow::Result { +async fn pingpong_0rtt(connecting: Connecting, i: u64) -> n0_snafu::Result { let connection = match connecting.into_0rtt() { Ok((connection, accepted)) => { trace!("0-RTT possible from our side"); @@ -80,7 +80,7 @@ async fn pingpong_0rtt(connecting: Connecting, i: u64) -> anyhow::Result { trace!("0-RTT not possible from our side"); - let connection = connecting.await?; + let connection = connecting.await.e()?; pingpong(&connection, future::ready(true), i).await?; connection } @@ -88,7 +88,7 @@ async fn pingpong_0rtt(connecting: Connecting, i: u64) -> anyhow::Result anyhow::Result<()> { +async fn connect(args: Args) -> 
n0_snafu::Result<()> { let node_addr = args.node.unwrap().node_addr().clone(); let endpoint = iroh::Endpoint::builder() .relay_mode(iroh::RelayMode::Disabled) @@ -102,7 +102,7 @@ async fn connect(args: Args) -> anyhow::Result<()> { .connect_with_opts(node_addr.clone(), PINGPONG_ALPN, Default::default()) .await?; let connection = if args.disable_0rtt { - let connection = connecting.await?; + let connection = connecting.await.e()?; trace!("connecting without 0-RTT"); pingpong(&connection, future::ready(true), i).await?; connection @@ -128,7 +128,7 @@ async fn connect(args: Args) -> anyhow::Result<()> { Ok(()) } -async fn accept(_args: Args) -> anyhow::Result<()> { +async fn accept(_args: Args) -> n0_snafu::Result<()> { let secret_key = get_or_generate_secret_key()?; let endpoint = iroh::Endpoint::builder() .alpns(vec![PINGPONG_ALPN.to_vec()]) @@ -139,7 +139,7 @@ async fn accept(_args: Args) -> anyhow::Result<()> { let mut addrs = endpoint.node_addr().stream(); let addr = loop { let Some(addr) = addrs.next().await else { - anyhow::bail!("Address stream closed"); + snafu::whatever!("Address stream closed"); }; if let Some(addr) = addr { if !addr.direct_addresses.is_empty() { @@ -153,18 +153,18 @@ async fn accept(_args: Args) -> anyhow::Result<()> { let accept = async move { while let Some(incoming) = endpoint.accept().await { tokio::spawn(async move { - let connecting = incoming.accept()?; + let connecting = incoming.accept().e()?; let (connection, _zero_rtt_accepted) = connecting .into_0rtt() .expect("accept into 0.5 RTT always succeeds"); - let (mut send, mut recv) = connection.accept_bi().await?; + let (mut send, mut recv) = connection.accept_bi().await.e()?; trace!("recv.is_0rtt: {}", recv.is_0rtt()); - let data = recv.read_to_end(8).await?; + let data = recv.read_to_end(8).await.e()?; trace!("recv: {}", data.len()); - send.write_all(&data).await?; - send.finish()?; + send.write_all(&data).await.e()?; + send.finish().e()?; connection.closed().await; - 
anyhow::Ok(()) + Ok::<_, n0_snafu::Error>(()) }); } }; @@ -180,7 +180,7 @@ async fn accept(_args: Args) -> anyhow::Result<()> { } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> n0_snafu::Result<()> { tracing_subscriber::fmt::init(); let args = Args::parse(); if args.node.is_some() { diff --git a/iroh/examples/connect-unreliable.rs b/iroh/examples/connect-unreliable.rs index b361b7e2687..bc40be070c6 100644 --- a/iroh/examples/connect-unreliable.rs +++ b/iroh/examples/connect-unreliable.rs @@ -9,6 +9,7 @@ use std::net::SocketAddr; use clap::Parser; use iroh::{watcher::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; +use n0_snafu::ResultExt; use tracing::info; // An example ALPN that we are using to communicate over the `Endpoint` @@ -28,7 +29,7 @@ struct Cli { } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> n0_snafu::Result<()> { tracing_subscriber::fmt::init(); println!("\nconnect (unreliable) example!\n"); let args = Cli::parse(); @@ -72,11 +73,11 @@ async fn main() -> anyhow::Result<()> { // Send a datagram over the connection. let message = format!("{me} is saying 'hello!'"); - conn.send_datagram(message.as_bytes().to_vec().into())?; + conn.send_datagram(message.as_bytes().to_vec().into()).e()?; // Read a datagram over the connection. - let message = conn.read_datagram().await?; - let message = String::from_utf8(message.into())?; + let message = conn.read_datagram().await.e()?; + let message = String::from_utf8(message.into()).e()?; println!("received: {message}"); Ok(()) diff --git a/iroh/examples/connect.rs b/iroh/examples/connect.rs index 3b0a247ae53..faa676cd0a2 100644 --- a/iroh/examples/connect.rs +++ b/iroh/examples/connect.rs @@ -7,9 +7,9 @@ //! Run the `listen` example first (`iroh/examples/listen.rs`), which will give you instructions on how to run this example to watch two nodes connect and exchange bytes. 
use std::net::SocketAddr; -use anyhow::Context; use clap::Parser; use iroh::{watcher::Watcher as _, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey}; +use n0_snafu::{Result, ResultExt}; use tracing::info; // An example ALPN that we are using to communicate over the `Endpoint` @@ -29,7 +29,7 @@ struct Cli { } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<()> { tracing_subscriber::fmt::init(); println!("\nconnect example!\n"); let args = Cli::parse(); @@ -78,15 +78,15 @@ async fn main() -> anyhow::Result<()> { info!("connected"); // Use the Quinn API to send and recv content. - let (mut send, mut recv) = conn.open_bi().await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; let message = format!("{me} is saying 'hello!'"); - send.write_all(message.as_bytes()).await?; + send.write_all(message.as_bytes()).await.e()?; // Call `finish` to close the send side of the connection gracefully. - send.finish()?; - let message = recv.read_to_end(100).await?; - let message = String::from_utf8(message)?; + send.finish().e()?; + let message = recv.read_to_end(100).await.e()?; + let message = String::from_utf8(message).e()?; println!("received: {message}"); // We received the last message: close all connections and allow for the close diff --git a/iroh/examples/dht_discovery.rs b/iroh/examples/dht_discovery.rs index c514474d766..64945293530 100644 --- a/iroh/examples/dht_discovery.rs +++ b/iroh/examples/dht_discovery.rs @@ -12,6 +12,7 @@ use std::str::FromStr; use clap::Parser; use iroh::{Endpoint, NodeId}; +use n0_snafu::ResultExt; use tracing::warn; use url::Url; @@ -40,7 +41,7 @@ enum PkarrRelay { } impl FromStr for PkarrRelay { - type Err = anyhow::Error; + type Err = url::ParseError; fn from_str(s: &str) -> Result { match s { @@ -60,7 +61,7 @@ fn build_discovery(args: Args) -> iroh::discovery::pkarr::dht::Builder { } } -async fn chat_server(args: Args) -> anyhow::Result<()> { +async fn chat_server(args: Args) -> 
n0_snafu::Result<()> { let secret_key = iroh::SecretKey::generate(rand::rngs::OsRng); let node_id = secret_key.public(); let discovery = build_discovery(args) @@ -72,7 +73,7 @@ async fn chat_server(args: Args) -> anyhow::Result<()> { .discovery(Box::new(discovery)) .bind() .await?; - let zid = pkarr::PublicKey::try_from(node_id.as_bytes())?.to_z32(); + let zid = pkarr::PublicKey::try_from(node_id.as_bytes()).e()?.to_z32(); println!("Listening on {}", node_id); println!("pkarr z32: {}", zid); println!("see https://app.pkarr.org/?pk={}", zid); @@ -87,11 +88,11 @@ async fn chat_server(args: Args) -> anyhow::Result<()> { } }; tokio::spawn(async move { - let connection = connecting.await?; + let connection = connecting.await.e()?; let remote_node_id = connection.remote_node_id()?; println!("got connection from {}", remote_node_id); // just leave the tasks hanging. this is just an example. - let (mut writer, mut reader) = connection.accept_bi().await?; + let (mut writer, mut reader) = connection.accept_bi().await.e()?; let _copy_to_stdout = tokio::spawn(async move { tokio::io::copy(&mut reader, &mut tokio::io::stdout()).await }); @@ -99,13 +100,13 @@ async fn chat_server(args: Args) -> anyhow::Result<()> { tokio::spawn( async move { tokio::io::copy(&mut tokio::io::stdin(), &mut writer).await }, ); - anyhow::Ok(()) + Ok::<_, n0_snafu::Error>(()) }); } Ok(()) } -async fn chat_client(args: Args) -> anyhow::Result<()> { +async fn chat_client(args: Args) -> n0_snafu::Result<()> { let remote_node_id = args.node_id.unwrap(); let secret_key = iroh::SecretKey::generate(rand::rngs::OsRng); let node_id = secret_key.public(); @@ -120,18 +121,18 @@ async fn chat_client(args: Args) -> anyhow::Result<()> { println!("We are {} and connecting to {}", node_id, remote_node_id); let connection = endpoint.connect(remote_node_id, CHAT_ALPN).await?; println!("connected to {}", remote_node_id); - let (mut writer, mut reader) = connection.open_bi().await?; + let (mut writer, mut reader) = 
connection.open_bi().await.e()?; let _copy_to_stdout = tokio::spawn(async move { tokio::io::copy(&mut reader, &mut tokio::io::stdout()).await }); let _copy_from_stdin = tokio::spawn(async move { tokio::io::copy(&mut tokio::io::stdin(), &mut writer).await }); - _copy_to_stdout.await??; - _copy_from_stdin.await??; + _copy_to_stdout.await.e()?.e()?; + _copy_from_stdin.await.e()?.e()?; Ok(()) } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> n0_snafu::Result<()> { tracing_subscriber::fmt::init(); let args = Args::parse(); if args.node_id.is_some() { diff --git a/iroh/examples/echo-no-router.rs b/iroh/examples/echo-no-router.rs index b9348ed9c95..3b9dc50bc5d 100644 --- a/iroh/examples/echo-no-router.rs +++ b/iroh/examples/echo-no-router.rs @@ -7,8 +7,8 @@ //! //! cargo run --example echo-no-router --features=examples -use anyhow::Result; use iroh::{watcher::Watcher as _, Endpoint, NodeAddr}; +use n0_snafu::{Error, Result, ResultExt}; /// Each protocol is identified by its ALPN string. /// @@ -37,16 +37,16 @@ async fn connect_side(addr: NodeAddr) -> Result<()> { let conn = endpoint.connect(addr, ALPN).await?; // Open a bidirectional QUIC stream - let (mut send, mut recv) = conn.open_bi().await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; // Send some data to be echoed - send.write_all(b"Hello, world!").await?; + send.write_all(b"Hello, world!").await.e()?; // Signal the end of data for this particular stream - send.finish()?; + send.finish().e()?; // Receive the echo, but limit reading up to maximum 1000 bytes - let response = recv.read_to_end(1000).await?; + let response = recv.read_to_end(1000).await.e()?; assert_eq!(&response, b"Hello, world!"); // Explicitly close the whole connection. 
@@ -85,7 +85,7 @@ async fn start_accept_side() -> Result { while let Some(incoming) = endpoint.accept().await { // spawn a task for each incoming connection, so we can serve multiple connections asynchronously tokio::spawn(async move { - let connection = incoming.await?; + let connection = incoming.await.e()?; // We can get the remote's node id from the connection. let node_id = connection.remote_node_id()?; @@ -93,26 +93,26 @@ async fn start_accept_side() -> Result { // Our protocol is a simple request-response protocol, so we expect the // connecting peer to open a single bi-directional stream. - let (mut send, mut recv) = connection.accept_bi().await?; + let (mut send, mut recv) = connection.accept_bi().await.e()?; // Echo any bytes received back directly. // This will keep copying until the sender signals the end of data on the stream. - let bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; + let bytes_sent = tokio::io::copy(&mut recv, &mut send).await.e()?; println!("Copied over {bytes_sent} byte(s)"); // By calling `finish` on the send stream we signal that we will not send anything // further, which makes the receive stream on the other end terminate. - send.finish()?; + send.finish().e()?; // Wait until the remote closes the connection, which it does once it // received the response. connection.closed().await; - anyhow::Ok(()) + Ok::<_, Error>(()) }); } - anyhow::Ok(()) + Ok::<_, Error>(()) } }); diff --git a/iroh/examples/echo.rs b/iroh/examples/echo.rs index e02ee825c3a..2aa2e3d86ab 100644 --- a/iroh/examples/echo.rs +++ b/iroh/examples/echo.rs @@ -6,13 +6,13 @@ //! //! cargo run --example echo --features=examples -use anyhow::Result; use iroh::{ endpoint::Connection, - protocol::{ProtocolHandler, Router}, + protocol::{ProtocolError, ProtocolHandler, Router}, watcher::Watcher as _, Endpoint, NodeAddr, }; +use n0_snafu::{Result, ResultExt}; /// Each protocol is identified by its ALPN string. 
/// @@ -28,7 +28,7 @@ async fn main() -> Result<()> { connect_side(node_addr).await?; // This makes sure the endpoint in the router is closed properly and connections close gracefully - router.shutdown().await?; + router.shutdown().await.e()?; Ok(()) } @@ -40,16 +40,16 @@ async fn connect_side(addr: NodeAddr) -> Result<()> { let conn = endpoint.connect(addr, ALPN).await?; // Open a bidirectional QUIC stream - let (mut send, mut recv) = conn.open_bi().await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; // Send some data to be echoed - send.write_all(b"Hello, world!").await?; + send.write_all(b"Hello, world!").await.e()?; // Signal the end of data for this particular stream - send.finish()?; + send.finish().e()?; // Receive the echo, but limit reading up to maximum 1000 bytes - let response = recv.read_to_end(1000).await?; + let response = recv.read_to_end(1000).await.e()?; assert_eq!(&response, b"Hello, world!"); // Explicitly close the whole connection. @@ -83,7 +83,7 @@ impl ProtocolHandler for Echo { /// /// The returned future runs on a newly spawned tokio task, so it can run as long as /// the connection lasts. - async fn accept(&self, connection: Connection) -> Result<()> { + async fn accept(&self, connection: Connection) -> Result<(), ProtocolError> { // We can get the remote's node id from the connection. let node_id = connection.remote_node_id()?; println!("accepted connection from {node_id}"); diff --git a/iroh/examples/listen-unreliable.rs b/iroh/examples/listen-unreliable.rs index 9e164222ca2..586fdded8a2 100644 --- a/iroh/examples/listen-unreliable.rs +++ b/iroh/examples/listen-unreliable.rs @@ -4,13 +4,14 @@ //! run this example from the project root: //! 
$ cargo run --example listen-unreliable use iroh::{watcher::Watcher as _, Endpoint, RelayMode, SecretKey}; +use n0_snafu::{Error, Result, ResultExt}; use tracing::{info, warn}; // An example ALPN that we are using to communicate over the `Endpoint` const EXAMPLE_ALPN: &[u8] = b"n0/iroh/examples/magic/0"; #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<()> { tracing_subscriber::fmt::init(); println!("\nlisten (unreliable) example!\n"); let secret_key = SecretKey::generate(rand::rngs::OsRng); @@ -68,7 +69,7 @@ async fn main() -> anyhow::Result<()> { } }; let alpn = connecting.alpn().await?; - let conn = connecting.await?; + let conn = connecting.await.e()?; let node_id = conn.remote_node_id()?; info!( "new (unreliable) connection from {node_id} with ALPN {}", @@ -78,14 +79,14 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { // use the `quinn` API to read a datagram off the connection, and send a datagra, in return while let Ok(message) = conn.read_datagram().await { - let message = String::from_utf8(message.into())?; + let message = String::from_utf8(message.into()).e()?; println!("received: {message}"); let message = format!("hi! you connected to {me}. 
bye bye"); - conn.send_datagram(message.as_bytes().to_vec().into())?; + conn.send_datagram(message.as_bytes().to_vec().into()).e()?; } - Ok::<_, anyhow::Error>(()) + Ok::<_, Error>(()) }); } // stop with SIGINT (ctrl-c) diff --git a/iroh/examples/listen.rs b/iroh/examples/listen.rs index b327d38fb74..ac1690c4cf5 100644 --- a/iroh/examples/listen.rs +++ b/iroh/examples/listen.rs @@ -6,13 +6,14 @@ use std::time::Duration; use iroh::{endpoint::ConnectionError, watcher::Watcher as _, Endpoint, RelayMode, SecretKey}; +use n0_snafu::ResultExt; use tracing::{debug, info, warn}; // An example ALPN that we are using to communicate over the `Endpoint` const EXAMPLE_ALPN: &[u8] = b"n0/iroh/examples/magic/0"; #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> n0_snafu::Result<()> { tracing_subscriber::fmt::init(); println!("\nlisten example!\n"); let secret_key = SecretKey::generate(rand::rngs::OsRng); @@ -68,7 +69,7 @@ async fn main() -> anyhow::Result<()> { } }; let alpn = connecting.alpn().await?; - let conn = connecting.await?; + let conn = connecting.await.e()?; let node_id = conn.remote_node_id()?; info!( "new connection from {node_id} with ALPN {}", @@ -79,16 +80,16 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { // accept a bi-directional QUIC connection // use the `quinn` APIs to send and recv content - let (mut send, mut recv) = conn.accept_bi().await?; + let (mut send, mut recv) = conn.accept_bi().await.e()?; debug!("accepted bi stream, waiting for data..."); - let message = recv.read_to_end(100).await?; - let message = String::from_utf8(message)?; + let message = recv.read_to_end(100).await.e()?; + let message = String::from_utf8(message).e()?; println!("received: {message}"); let message = format!("hi! you connected to {me}. 
bye bye"); - send.write_all(message.as_bytes()).await?; + send.write_all(message.as_bytes()).await.e()?; // call `finish` to close the connection gracefully - send.finish()?; + send.finish().e()?; // We sent the last message, so wait for the client to close the connection once // it received this message. @@ -102,7 +103,7 @@ async fn main() -> anyhow::Result<()> { if res.is_err() { println!("node {node_id} did not disconnect within 3 seconds"); } - Ok::<_, anyhow::Error>(()) + Ok::<_, n0_snafu::Error>(()) }); } // stop with SIGINT (ctrl-c) diff --git a/iroh/examples/locally-discovered-nodes.rs b/iroh/examples/locally-discovered-nodes.rs index aca9da3b0b1..7d691e37a03 100644 --- a/iroh/examples/locally-discovered-nodes.rs +++ b/iroh/examples/locally-discovered-nodes.rs @@ -5,9 +5,9 @@ //! This is an async, non-determinate process, so the number of NodeIDs discovered each time may be different. If you have other iroh endpoints or iroh nodes with [`MdnsDiscovery`] enabled, it may discover those nodes as well. 
use std::time::Duration; -use anyhow::Result; use iroh::{node_info::UserData, Endpoint, NodeId}; use n0_future::StreamExt; +use n0_snafu::Result; use tokio::task::JoinSet; #[tokio::main] @@ -66,7 +66,7 @@ async fn main() -> Result<()> { ep.set_user_data_for_discovery(Some(ud)); tokio::time::sleep(Duration::from_secs(3)).await; ep.close().await; - anyhow::Ok(()) + Ok::<_, n0_snafu::Error>(()) }); } diff --git a/iroh/examples/search.rs b/iroh/examples/search.rs index e007718dffa..7c9a9798b8c 100644 --- a/iroh/examples/search.rs +++ b/iroh/examples/search.rs @@ -31,13 +31,13 @@ use std::{collections::BTreeSet, sync::Arc}; -use anyhow::Result; use clap::Parser; use iroh::{ endpoint::Connection, - protocol::{ProtocolHandler, Router}, + protocol::{ProtocolError, ProtocolHandler, Router}, Endpoint, NodeId, }; +use n0_snafu::{Result, ResultExt}; use tokio::sync::Mutex; use tracing_subscriber::{prelude::*, EnvFilter}; @@ -97,7 +97,7 @@ async fn main() -> Result<()> { } // Wait for Ctrl-C to be pressed. - tokio::signal::ctrl_c().await?; + tokio::signal::ctrl_c().await.e()?; } Command::Query { node_id, query } => { // Query the remote node. @@ -110,7 +110,7 @@ async fn main() -> Result<()> { } } - router.shutdown().await?; + router.shutdown().await.e()?; Ok(()) } @@ -126,7 +126,7 @@ impl ProtocolHandler for BlobSearch { /// /// The returned future runs on a newly spawned tokio task, so it can run as long as /// the connection lasts. - async fn accept(&self, connection: Connection) -> Result<()> { + async fn accept(&self, connection: Connection) -> Result<(), ProtocolError> { // We can get the remote's node id from the connection. let node_id = connection.remote_node_id()?; println!("accepted connection from {node_id}"); @@ -136,16 +136,21 @@ impl ProtocolHandler for BlobSearch { let (mut send, mut recv) = connection.accept_bi().await?; // We read the query from the receive stream, while enforcing a max query length. 
- let query_bytes = recv.read_to_end(64).await?; + let query_bytes = recv + .read_to_end(64) + .await + .map_err(ProtocolError::from_err)?; // Now, we can perform the actual query on our local database. - let query = String::from_utf8(query_bytes)?; + let query = String::from_utf8(query_bytes).map_err(ProtocolError::from_err)?; let num_matches = self.query_local(&query).await; // We want to return a list of hashes. We do the simplest thing possible, and just send // one hash after the other. Because the hashes have a fixed size of 32 bytes, this is // very easy to parse on the other end. - send.write_all(&num_matches.to_le_bytes()).await?; + send.write_all(&num_matches.to_le_bytes()) + .await + .map_err(ProtocolError::from_err)?; // By calling `finish` on the send stream we signal that we will not send anything // further, which makes the receive stream on the other end terminate. @@ -176,21 +181,21 @@ impl BlobSearch { let conn = self.endpoint.connect(node_id, ALPN).await?; // Open a bi-directional in our connection. - let (mut send, mut recv) = conn.open_bi().await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; // Send our query. - send.write_all(query.as_bytes()).await?; + send.write_all(query.as_bytes()).await.e()?; // Finish the send stream, signalling that no further data will be sent. // This makes the `read_to_end` call on the accepting side terminate. - send.finish()?; + send.finish().e()?; // The response is a 64 bit integer // We simply read it into a byte buffer. let mut num_matches = [0u8; 8]; // Read 8 bytes from the stream. 
- recv.read_exact(&mut num_matches).await?; + recv.read_exact(&mut num_matches).await.e()?; let num_matches = u64::from_le_bytes(num_matches); diff --git a/iroh/examples/transfer.rs b/iroh/examples/transfer.rs index 7b1cfa82a3e..21ecad77ce0 100644 --- a/iroh/examples/transfer.rs +++ b/iroh/examples/transfer.rs @@ -3,7 +3,6 @@ use std::{ time::{Duration, Instant}, }; -use anyhow::{Context, Result}; use bytes::Bytes; use clap::{Parser, Subcommand}; use indicatif::HumanBytes; @@ -19,6 +18,7 @@ use iroh::{ }; use iroh_base::ticket::NodeTicket; use n0_future::task::AbortOnDropHandle; +use n0_snafu::{Result, ResultExt}; use tokio_stream::StreamExt; use tracing::{info, warn}; use url::Url; @@ -291,7 +291,7 @@ async fn provide(endpoint: Endpoint, size: u64) -> Result<()> { // spawn a task to handle reading and writing off of the connection let endpoint_clone = endpoint.clone(); tokio::spawn(async move { - let conn = connecting.await?; + let conn = connecting.await.e()?; let node_id = conn.remote_node_id()?; info!( "new connection from {node_id} with ALPN {}", @@ -306,10 +306,10 @@ async fn provide(endpoint: Endpoint, size: u64) -> Result<()> { // accept a bi-directional QUIC connection // use the `quinn` APIs to send and recv content - let (mut send, mut recv) = conn.accept_bi().await?; + let (mut send, mut recv) = conn.accept_bi().await.e()?; tracing::debug!("accepted bi stream, waiting for data..."); - let message = recv.read_to_end(100).await?; - let message = String::from_utf8(message)?; + let message = recv.read_to_end(100).await.e()?; + let message = String::from_utf8(message).e()?; println!("[{remote}] Received: \"{message}\""); let start = Instant::now(); @@ -338,7 +338,7 @@ async fn provide(endpoint: Endpoint, size: u64) -> Result<()> { } else { println!("[{remote}] Disconnected"); } - Ok::<_, anyhow::Error>(()) + Ok::<_, n0_snafu::Error>(()) }); } @@ -362,19 +362,21 @@ async fn fetch(endpoint: Endpoint, ticket: &str) -> Result<()> { let _guard = 
watch_conn_type(&endpoint, remote_node_id); // Use the Quinn API to send and recv content. - let (mut send, mut recv) = conn.open_bi().await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; let message = format!("{me} is saying hello!"); - send.write_all(message.as_bytes()).await?; + send.write_all(message.as_bytes()).await.e()?; // Call `finish` to signal no more data will be sent on this stream. - send.finish()?; + send.finish().e()?; println!("Sent: \"{message}\""); let (len, time_to_first_byte, chnk) = drain_stream(&mut recv, false).await?; // We received the last message: close all connections and allow for the close // message to be sent. - tokio::time::timeout(Duration::from_secs(3), endpoint.close()).await?; + tokio::time::timeout(Duration::from_secs(3), endpoint.close()) + .await + .e()?; let duration = start.elapsed(); println!( @@ -401,7 +403,7 @@ async fn drain_stream( let mut num_chunks: u64 = 0; if read_unordered { - while let Some(chunk) = stream.read_chunk(usize::MAX, false).await? { + while let Some(chunk) = stream.read_chunk(usize::MAX, false).await.e()? { if first_byte { time_to_first_byte = download_start.elapsed(); first_byte = false; @@ -423,7 +425,7 @@ async fn drain_stream( Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), ]; - while let Some(n) = stream.read_chunks(&mut bufs[..]).await? { + while let Some(n) = stream.read_chunks(&mut bufs[..]).await.e()? 
{ if first_byte { time_to_first_byte = download_start.elapsed(); first_byte = false; @@ -469,9 +471,9 @@ async fn send_data_on_stream( Ok(()) } -fn parse_byte_size(s: &str) -> Result { +fn parse_byte_size(s: &str) -> std::result::Result { let cfg = parse_size::Config::new().with_binary(); - cfg.parse_size(s).map_err(|e| anyhow::anyhow!(e)) + cfg.parse_size(s) } fn watch_conn_type(endpoint: &Endpoint, node_id: NodeId) -> AbortOnDropHandle<()> { diff --git a/iroh/src/disco.rs b/iroh/src/disco.rs index f1beaedf778..f33bf2e2e29 100644 --- a/iroh/src/disco.rs +++ b/iroh/src/disco.rs @@ -23,10 +23,11 @@ use std::{ net::{IpAddr, SocketAddr}, }; -use anyhow::{anyhow, bail, ensure, Context, Result}; use data_encoding::HEXLOWER; use iroh_base::{PublicKey, RelayUrl}; +use nested_enum_utils::common_fields; use serde::{Deserialize, Serialize}; +use snafu::{ensure, Snafu}; use url::Url; // TODO: custom magicn @@ -202,13 +203,12 @@ pub struct CallMeMaybe { } impl Ping { - fn from_bytes(ver: u8, p: &[u8]) -> Result { - ensure!(ver == V0, "invalid version"); + fn from_bytes(p: &[u8]) -> Result { // Deliberately lax on longer-than-expected messages, for future compatibility. 
- ensure!(p.len() >= PING_LEN, "message too short"); + ensure!(p.len() >= PING_LEN, TooShortSnafu); let tx_id: [u8; TX_LEN] = p[..TX_LEN].try_into().expect("length checked"); let raw_key = &p[TX_LEN..TX_LEN + iroh_base::PublicKey::LENGTH]; - let node_key = PublicKey::try_from(raw_key)?; + let node_key = PublicKey::try_from(raw_key).map_err(|_| InvalidEncodingSnafu.build())?; let tx_id = stun_rs::TransactionId::from(tx_id); Ok(Ping { tx_id, node_key }) @@ -226,22 +226,37 @@ impl Ping { } } -fn send_addr_from_bytes(p: &[u8]) -> Result { - ensure!(p.len() > 2, "too short"); +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ParseError { + #[snafu(display("message is too short"))] + TooShort {}, + #[snafu(display("invalid encoding"))] + InvalidEncoding {}, + #[snafu(display("unknown format"))] + UnknownFormat {}, +} + +fn send_addr_from_bytes(p: &[u8]) -> Result { + ensure!(p.len() > 2, TooShortSnafu); match p[0] { 0u8 => { - let bytes: [u8; EP_LENGTH] = p[1..].try_into().context("invalid length")?; + let bytes: [u8; EP_LENGTH] = p[1..].try_into().map_err(|_| TooShortSnafu.build())?; let addr = socket_addr_from_bytes(bytes); Ok(SendAddr::Udp(addr)) } 1u8 => { - let s = std::str::from_utf8(&p[1..])?; - let u: Url = s.parse()?; + let s = std::str::from_utf8(&p[1..]).map_err(|_| InvalidEncodingSnafu.build())?; + let u: Url = s.parse().map_err(|_| InvalidEncodingSnafu.build())?; Ok(SendAddr::Relay(u.into())) } - _ => { - bail!("invalid addr type {}", p[0]); - } + _ => Err(UnknownFormatSnafu.build()), } } @@ -286,9 +301,9 @@ fn socket_addr_as_bytes(addr: &SocketAddr) -> [u8; EP_LENGTH] { } impl Pong { - fn from_bytes(ver: u8, p: &[u8]) -> Result { - ensure!(ver == V0, "invalid version"); - let tx_id: [u8; TX_LEN] = p[..TX_LEN].try_into().context("message too short")?; + fn from_bytes(p: &[u8]) -> Result { + let tx_id: [u8; TX_LEN] = 
p[..TX_LEN].try_into().map_err(|_| TooShortSnafu.build())?; + let tx_id = stun_rs::TransactionId::from(tx_id); let src = send_addr_from_bytes(&p[TX_LEN..])?; @@ -310,9 +325,8 @@ impl Pong { } impl CallMeMaybe { - fn from_bytes(ver: u8, p: &[u8]) -> Result { - ensure!(ver == V0, "invalid version"); - ensure!(p.len() % EP_LENGTH == 0, "invalid entries"); + fn from_bytes(p: &[u8]) -> Result { + ensure!(p.len() % EP_LENGTH == 0, InvalidEncodingSnafu); let num_entries = p.len() / EP_LENGTH; let mut m = CallMeMaybe { @@ -320,7 +334,8 @@ impl CallMeMaybe { }; for chunk in p.chunks_exact(EP_LENGTH) { - let bytes: [u8; EP_LENGTH] = chunk.try_into().context("chunk must match")?; + let bytes: [u8; EP_LENGTH] = + chunk.try_into().map_err(|_| InvalidEncodingSnafu.build())?; let src = socket_addr_from_bytes(bytes); m.my_numbers.push(src); } @@ -348,23 +363,25 @@ impl CallMeMaybe { impl Message { /// Parses the encrypted part of the message from inside the nacl secretbox. - pub fn from_bytes(p: &[u8]) -> Result { - ensure!(p.len() >= 2, "message too short"); + pub fn from_bytes(p: &[u8]) -> Result { + ensure!(p.len() >= 2, TooShortSnafu); + + let t = MessageType::try_from(p[0]).map_err(|_| UnknownFormatSnafu.build())?; + let version = p[1]; + ensure!(version == V0, UnknownFormatSnafu); - let t = MessageType::try_from(p[0]).map_err(|v| anyhow!("unknown message type: {}", v))?; - let ver = p[1]; let p = &p[2..]; match t { MessageType::Ping => { - let ping = Ping::from_bytes(ver, p)?; + let ping = Ping::from_bytes(p)?; Ok(Message::Ping(ping)) } MessageType::Pong => { - let pong = Pong::from_bytes(ver, p)?; + let pong = Pong::from_bytes(p)?; Ok(Message::Pong(pong)) } MessageType::CallMeMaybe => { - let cm = CallMeMaybe::from_bytes(ver, p)?; + let cm = CallMeMaybe::from_bytes(p)?; Ok(Message::CallMeMaybe(cm)) } } diff --git a/iroh/src/discovery.rs b/iroh/src/discovery.rs index d7b0a415408..228b488cef9 100644 --- a/iroh/src/discovery.rs +++ b/iroh/src/discovery.rs @@ -54,7 +54,7 @@ //! 
Endpoint, SecretKey, //! }; //! -//! # async fn wrapper() -> anyhow::Result<()> { +//! # async fn wrapper() -> n0_snafu::Result<()> { //! let secret_key = SecretKey::generate(rand::rngs::OsRng); //! let discovery = ConcurrentDiscovery::from_services(vec![ //! Box::new(PkarrPublisher::n0_dns(secret_key.clone())), @@ -81,7 +81,7 @@ //! # use iroh::discovery::ConcurrentDiscovery; //! # use iroh::SecretKey; //! # -//! # async fn wrapper() -> anyhow::Result<()> { +//! # async fn wrapper() -> n0_snafu::Result<()> { //! # let secret_key = SecretKey::generate(rand::rngs::OsRng); //! let discovery = ConcurrentDiscovery::from_services(vec![ //! Box::new(PkarrPublisher::n0_dns(secret_key.clone())), @@ -107,8 +107,8 @@ use std::sync::Arc; -use anyhow::{anyhow, ensure, Result}; -use iroh_base::{NodeAddr, NodeId}; +use iroh_base::{NodeAddr, NodeId, PublicKey}; +use iroh_relay::node_info::EncodingError; use n0_future::{ boxed::BoxStream, stream::StreamExt, @@ -116,10 +116,12 @@ use n0_future::{ time::{self, Duration}, Stream, TryStreamExt, }; +use nested_enum_utils::common_fields; +use snafu::{ensure, IntoError, Snafu}; use tokio::sync::oneshot; use tracing::{debug, error_span, warn, Instrument}; -pub use crate::node_info::{NodeData, NodeInfo, UserData}; +pub use crate::node_info::{NodeData, NodeInfo, ParseError, UserData}; use crate::Endpoint; #[cfg(not(wasm_browser))] @@ -130,6 +132,57 @@ pub mod mdns; pub mod pkarr; pub mod static_provider; +/// Discovery errors +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum DiscoveryError { + #[snafu(display("No discovery service configured"))] + NoServiceConfigured {}, + #[snafu(display("Cannot resolve node id"))] + NodeId { node_id: PublicKey }, + #[snafu(display("Discovery produced no results"))] + NoResults { node_id: PublicKey }, + #[snafu(display("Error encoding the signed packet"))] + SignedPacket { + 
#[snafu(source(from(EncodingError, Box::new)))] + source: Box, + }, + #[snafu(display("Error parsing the signed packet"))] + ParsePacket { + #[snafu(source(from(ParseError, Box::new)))] + source: Box, + }, + #[snafu(display("Service '{provenance}' error"))] + User { + provenance: &'static str, + source: Box, + }, +} + +impl DiscoveryError { + /// Creates a new user error from an arbitrary error type. + pub fn from_err( + provenance: &'static str, + source: T, + ) -> Self { + UserSnafu { provenance }.into_error(Box::new(source)) + } + + /// Creates a new user error from an arbitrary boxed error type. + pub fn from_err_box( + provenance: &'static str, + source: Box, + ) -> Self { + UserSnafu { provenance }.into_error(source) + } +} + /// Node discovery for [`super::Endpoint`]. /// /// This trait defines publishing and resolving addressing information for a [`NodeId`]. @@ -164,7 +217,7 @@ pub trait Discovery: std::fmt::Debug + Send + Sync { &self, _endpoint: Endpoint, _node_id: NodeId, - ) -> Option>> { + ) -> Option>> { None } @@ -330,7 +383,7 @@ impl Discovery for ConcurrentDiscovery { &self, endpoint: Endpoint, node_id: NodeId, - ) -> Option>> { + ) -> Option>> { let streams = self .services .iter() @@ -359,14 +412,14 @@ const MAX_AGE: Duration = Duration::from_secs(10); /// A wrapper around a tokio task which runs a node discovery. pub(super) struct DiscoveryTask { - on_first_rx: oneshot::Receiver>, + on_first_rx: oneshot::Receiver>, _task: AbortOnDropHandle<()>, } impl DiscoveryTask { /// Starts a discovery task. 
- pub(super) fn start(ep: Endpoint, node_id: NodeId) -> Result { - ensure!(ep.discovery().is_some(), "No discovery services configured"); + pub(super) fn start(ep: Endpoint, node_id: NodeId) -> Result { + ensure!(ep.discovery().is_some(), NoServiceConfiguredSnafu); let (on_first_tx, on_first_rx) = oneshot::channel(); let me = ep.node_id(); let task = task::spawn( @@ -392,12 +445,12 @@ impl DiscoveryTask { ep: &Endpoint, node_id: NodeId, delay: Option, - ) -> Result> { + ) -> Result, DiscoveryError> { // If discovery is not needed, don't even spawn a task. if !Self::needs_discovery(ep, node_id) { return Ok(None); } - ensure!(ep.discovery().is_some(), "No discovery services configured"); + ensure!(ep.discovery().is_some(), NoServiceConfiguredSnafu); let (on_first_tx, on_first_rx) = oneshot::channel(); let ep = ep.clone(); let me = ep.node_id(); @@ -425,19 +478,20 @@ impl DiscoveryTask { } /// Waits until the discovery task produced at least one result. - pub(super) async fn first_arrived(&mut self) -> Result<()> { + pub(super) async fn first_arrived(&mut self) -> Result<(), DiscoveryError> { let fut = &mut self.on_first_rx; - fut.await??; + fut.await.expect("sender dropped")?; Ok(()) } - fn create_stream(ep: &Endpoint, node_id: NodeId) -> Result>> { - let discovery = ep - .discovery() - .ok_or_else(|| anyhow!("No discovery service configured"))?; + fn create_stream( + ep: &Endpoint, + node_id: NodeId, + ) -> Result>, DiscoveryError> { + let discovery = ep.discovery().ok_or(NoServiceConfiguredSnafu.build())?; let stream = discovery .resolve(ep.clone(), node_id) - .ok_or_else(|| anyhow!("No discovery service can resolve node {node_id}",))?; + .ok_or(NodeIdSnafu { node_id }.build())?; Ok(stream) } @@ -465,7 +519,11 @@ impl DiscoveryTask { } } - async fn run(ep: Endpoint, node_id: NodeId, on_first_tx: oneshot::Sender>) { + async fn run( + ep: Endpoint, + node_id: NodeId, + on_first_tx: oneshot::Sender>, + ) { let mut stream = match Self::create_stream(&ep, node_id) { 
Ok(stream) => stream, Err(err) => { @@ -500,8 +558,7 @@ impl DiscoveryTask { } } if let Some(tx) = on_first_tx.take() { - let err = anyhow!("Discovery produced no results for {}", node_id.fmt_short()); - tx.send(Err(err)).ok(); + tx.send(Err(NoResultsSnafu { node_id }.build())).ok(); } } } @@ -514,9 +571,12 @@ impl DiscoveryTask { /// will return the oldest [`DiscoveryItem`] that is still retained. /// /// Includes the number of skipped messages. -#[derive(Debug, thiserror::Error)] -#[error("channel lagged by {0}")] -pub struct Lagged(pub u64); +#[derive(Debug, Snafu)] +#[snafu(display("channel lagged by {val}"))] +pub struct Lagged { + /// The number of skipped messages + pub val: u64, +} #[derive(Clone, Debug)] pub(super) struct DiscoverySubscribers { @@ -537,7 +597,7 @@ impl DiscoverySubscribers { pub(crate) fn subscribe(&self) -> impl Stream> { use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; let recv = self.inner.subscribe(); - BroadcastStream::new(recv).map_err(|BroadcastStreamRecvError::Lagged(n)| Lagged(n)) + BroadcastStream::new(recv).map_err(|BroadcastStreamRecvError::Lagged(n)| Lagged { val: n }) } pub(crate) fn send(&self, item: DiscoveryItem) { @@ -556,11 +616,10 @@ mod tests { time::SystemTime, }; - use anyhow::Context; use iroh_base::{NodeAddr, SecretKey}; + use n0_snafu::{Error, Result, ResultExt}; use quinn::{IdleTimeout, TransportConfig}; use rand::Rng; - use testresult::TestResult; use tokio_util::task::AbortOnDropHandle; use tracing_test::traced_test; @@ -636,7 +695,7 @@ mod tests { &self, endpoint: Endpoint, node_id: NodeId, - ) -> Option>> { + ) -> Option>> { let addr_info = if self.resolve_wrong { let ts = system_time_now() - 100_000; let port: u16 = rand::thread_rng().gen_range(10_000..20_000); @@ -688,7 +747,7 @@ mod tests { &self, _endpoint: Endpoint, _node_id: NodeId, - ) -> Option>> { + ) -> Option>> { Some(n0_future::stream::empty().boxed()) } } @@ -698,7 +757,7 @@ mod tests { /// This is a smoke test for 
our discovery mechanism. #[tokio::test] #[traced_test] - async fn endpoint_discovery_simple_shared() -> anyhow::Result<()> { + async fn endpoint_discovery_simple_shared() -> Result { let disco_shared = TestDiscoveryShared::default(); let (ep1, _guard1) = { let secret = SecretKey::generate(rand::thread_rng()); @@ -720,7 +779,7 @@ mod tests { /// This test adds an empty discovery which provides no addresses. #[tokio::test] #[traced_test] - async fn endpoint_discovery_combined_with_empty() -> anyhow::Result<()> { + async fn endpoint_discovery_combined_with_empty() -> Result { let disco_shared = TestDiscoveryShared::default(); let (ep1, _guard1) = { let secret = SecretKey::generate(rand::thread_rng()); @@ -754,7 +813,7 @@ mod tests { /// will connect successfully. #[tokio::test] #[traced_test] - async fn endpoint_discovery_combined_with_empty_and_wrong() -> anyhow::Result<()> { + async fn endpoint_discovery_combined_with_empty_and_wrong() -> Result { let disco_shared = TestDiscoveryShared::default(); let (ep1, _guard1) = { let secret = SecretKey::generate(rand::thread_rng()); @@ -781,7 +840,7 @@ mod tests { /// This test only has the "lying" discovery. It is here to make sure that this actually fails. 
#[tokio::test] #[traced_test] - async fn endpoint_discovery_combined_wrong_only() -> anyhow::Result<()> { + async fn endpoint_discovery_combined_wrong_only() -> Result { let disco_shared = TestDiscoveryShared::default(); let (ep1, _guard1) = { let secret = SecretKey::generate(rand::thread_rng()); @@ -800,7 +859,7 @@ mod tests { // 10x faster test via a 3s idle timeout instead of the 30s default let mut config = TransportConfig::default(); config.keep_alive_interval(Some(Duration::from_secs(1))); - config.max_idle_timeout(Some(IdleTimeout::try_from(Duration::from_secs(3))?)); + config.max_idle_timeout(Some(IdleTimeout::try_from(Duration::from_secs(3)).unwrap())); let opts = ConnectOptions::new().with_transport_config(Arc::new(config)); let res = ep2 @@ -815,7 +874,7 @@ mod tests { /// Connect should still succeed because the discovery service will be invoked (after a delay). #[tokio::test] #[traced_test] - async fn endpoint_discovery_with_wrong_existing_addr() -> anyhow::Result<()> { + async fn endpoint_discovery_with_wrong_existing_addr() -> Result { let disco_shared = TestDiscoveryShared::default(); let (ep1, _guard1) = { let secret = SecretKey::generate(rand::thread_rng()); @@ -840,7 +899,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn endpoint_discovery_watch() -> anyhow::Result<()> { + async fn endpoint_discovery_watch() -> Result { let disco_shared = TestDiscoveryShared::default(); let (ep1, _guard1) = { let secret = SecretKey::generate(rand::thread_rng()); @@ -887,7 +946,7 @@ mod tests { async fn new_endpoint( secret: SecretKey, disco: impl Discovery + 'static, - ) -> (Endpoint, AbortOnDropHandle>) { + ) -> (Endpoint, AbortOnDropHandle>) { let ep = Endpoint::builder() .secret_key(secret) .discovery(Box::new(disco)) @@ -905,11 +964,11 @@ mod tests { // we skip accept() errors, they can be caused by retransmits while let Some(connecting) = ep.accept().await.and_then(|inc| inc.accept().ok()) { // Just accept incoming connections, but don't do anything 
with them. - let conn = connecting.await?; + let conn = connecting.await.context("connecting")?; connections.push(conn); } - anyhow::Ok(()) + Ok::<_, Error>(()) } }); @@ -924,7 +983,7 @@ mod tests { } #[tokio::test] - async fn test_arc_discovery() -> TestResult { + async fn test_arc_discovery() -> Result { let discovery = Arc::new(EmptyDiscovery); let _ep = Endpoint::builder() @@ -945,10 +1004,10 @@ mod tests { /// publish to. The DNS and pkarr servers share their state. #[cfg(test)] mod test_dns_pkarr { - use anyhow::Result; use iroh_base::{NodeAddr, SecretKey}; use iroh_relay::{node_info::UserData, RelayMap}; use n0_future::time::Duration; + use n0_snafu::{Error, Result, ResultExt}; use tokio_util::task::AbortOnDropHandle; use tracing_test::traced_test; @@ -969,13 +1028,17 @@ mod test_dns_pkarr { async fn dns_resolve() -> Result<()> { let origin = "testdns.example".to_string(); let state = State::new(origin.clone()); - let (nameserver, _dns_drop_guard) = run_dns_server(state.clone()).await?; + let (nameserver, _dns_drop_guard) = run_dns_server(state.clone()) + .await + .context("Running DNS server")?; let secret_key = SecretKey::generate(rand::thread_rng()); let node_info = NodeInfo::new(secret_key.public()) .with_relay_url(Some("https://relay.example".parse().unwrap())); let signed_packet = node_info.to_pkarr_signed_packet(&secret_key, 30)?; - state.upsert(signed_packet)?; + state + .upsert(signed_packet) + .context("update and insert signed packet")?; let resolver = DnsResolver::with_nameserver(nameserver); let resolved = resolver @@ -992,7 +1055,9 @@ mod test_dns_pkarr { async fn pkarr_publish_dns_resolve() -> Result<()> { let origin = "testdns.example".to_string(); - let dns_pkarr_server = DnsPkarrServer::run_with_origin(origin.clone()).await?; + let dns_pkarr_server = DnsPkarrServer::run_with_origin(origin.clone()) + .await + .context("DnsPkarrServer")?; let secret_key = SecretKey::generate(rand::thread_rng()); let node_id = secret_key.public(); @@ -1007,7 
+1072,10 @@ mod test_dns_pkarr { // does not block, update happens in background task publisher.update_node_data(&data); // wait until our shared state received the update from pkarr publishing - dns_pkarr_server.on_node(&node_id, PUBLISH_TIMEOUT).await?; + dns_pkarr_server + .on_node(&node_id, PUBLISH_TIMEOUT) + .await + .context("wait for on node update")?; let resolved = resolver.lookup_node_by_id(&node_id, &origin).await?; println!("resolved {resolved:?}"); @@ -1027,7 +1095,7 @@ mod test_dns_pkarr { #[tokio::test] #[traced_test] async fn pkarr_publish_dns_discover() -> Result<()> { - let dns_pkarr_server = DnsPkarrServer::run().await?; + let dns_pkarr_server = DnsPkarrServer::run().await.context("DnsPkarrServer run")?; let (relay_map, _relay_url, _relay_guard) = run_relay_server().await?; let (ep1, _guard1) = ep_with_discovery(&relay_map, &dns_pkarr_server).await?; @@ -1036,7 +1104,8 @@ mod test_dns_pkarr { // wait until our shared state received the update from pkarr publishing dns_pkarr_server .on_node(&ep1.node_id(), PUBLISH_TIMEOUT) - .await?; + .await + .context("wait for on node update")?; // we connect only by node id! let res = ep2.connect(ep1.node_id(), TEST_ALPN).await; @@ -1064,11 +1133,11 @@ mod test_dns_pkarr { async move { // we skip accept() errors, they can be caused by retransmits while let Some(connecting) = ep.accept().await.and_then(|inc| inc.accept().ok()) { - let _conn = connecting.await?; + let _conn = connecting.await.context("connecting")?; // Just accept incoming connections, but don't do anything with them. } - anyhow::Ok(()) + Ok::<_, Error>(()) } }); diff --git a/iroh/src/discovery/dns.rs b/iroh/src/discovery/dns.rs index 6de92cadbd3..950160af229 100644 --- a/iroh/src/discovery/dns.rs +++ b/iroh/src/discovery/dns.rs @@ -1,10 +1,10 @@ //! 
DNS node discovery for iroh -use anyhow::Result; use iroh_base::NodeId; pub use iroh_relay::dns::{N0_DNS_NODE_ORIGIN_PROD, N0_DNS_NODE_ORIGIN_STAGING}; use n0_future::boxed::BoxStream; +use super::DiscoveryError; use crate::{ discovery::{Discovery, DiscoveryItem}, endpoint::force_staging_infra, @@ -63,13 +63,18 @@ impl DnsDiscovery { } impl Discovery for DnsDiscovery { - fn resolve(&self, ep: Endpoint, node_id: NodeId) -> Option>> { + fn resolve( + &self, + ep: Endpoint, + node_id: NodeId, + ) -> Option>> { let resolver = ep.dns_resolver().clone(); let origin_domain = self.origin_domain.clone(); let fut = async move { let node_info = resolver .lookup_node_by_id_staggered(&node_id, &origin_domain, DNS_STAGGERING_MS) - .await?; + .await + .map_err(|e| DiscoveryError::from_err("dns", e))?; Ok(DiscoveryItem::new(node_info, "dns", None)) }; let stream = n0_future::stream::once_future(fut); diff --git a/iroh/src/discovery/mdns.rs b/iroh/src/discovery/mdns.rs index 38f3079952f..7836f0dba14 100644 --- a/iroh/src/discovery/mdns.rs +++ b/iroh/src/discovery/mdns.rs @@ -35,7 +35,6 @@ use std::{ net::{IpAddr, SocketAddr}, }; -use anyhow::Result; use derive_more::FromStr; use iroh_base::{NodeId, PublicKey}; use n0_future::{ @@ -47,6 +46,7 @@ use swarm_discovery::{Discoverer, DropGuard, IpClass, Peer}; use tokio::sync::mpsc::{self, error::TrySendError}; use tracing::{debug, error, info_span, trace, warn, Instrument}; +use super::DiscoveryError; use crate::{ discovery::{Discovery, DiscoveryItem, NodeData, NodeInfo}, watcher::{Watchable, Watcher as _}, @@ -83,7 +83,7 @@ pub struct MdnsDiscovery { #[derive(Debug)] enum Message { Discovery(String, Peer), - Resolve(NodeId, mpsc::Sender>), + Resolve(NodeId, mpsc::Sender>), Timeout(NodeId, usize), Subscribe(mpsc::Sender), } @@ -138,7 +138,7 @@ impl MdnsDiscovery { /// /// # Panics /// This relies on [`tokio::runtime::Handle::current`] and will panic if called outside of the context of a tokio runtime. 
- pub fn new(node_id: NodeId) -> Result { + pub fn new(node_id: NodeId) -> Result { debug!("Creating new MdnsDiscovery service"); let (send, mut recv) = mpsc::channel(64); let task_sender = send.clone(); @@ -154,7 +154,7 @@ impl MdnsDiscovery { let mut last_id = 0; let mut senders: HashMap< PublicKey, - HashMap>>, + HashMap>>, > = HashMap::default(); let mut timeouts = JoinSet::new(); loop { @@ -306,7 +306,7 @@ impl MdnsDiscovery { sender: mpsc::Sender, socketaddrs: BTreeSet, rt: &tokio::runtime::Handle, - ) -> Result { + ) -> Result { let spawn_rt = rt.clone(); let callback = move |node_id: &str, peer: &Peer| { trace!( @@ -332,7 +332,9 @@ impl MdnsDiscovery { for addr in addrs { discoverer = discoverer.with_addrs(addr.0, addr.1); } - discoverer.spawn(rt) + discoverer + .spawn(rt) + .map_err(|e| DiscoveryError::from_err_box("mdns", e.into_boxed_dyn_error())) } fn socketaddrs_to_addrs(socketaddrs: &BTreeSet) -> HashMap> { @@ -373,7 +375,11 @@ fn peer_to_discovery_item(peer: &Peer, node_id: &NodeId) -> DiscoveryItem { } impl Discovery for MdnsDiscovery { - fn resolve(&self, _ep: Endpoint, node_id: NodeId) -> Option>> { + fn resolve( + &self, + _ep: Endpoint, + node_id: NodeId, + ) -> Option>> { use futures_util::FutureExt; let (send, recv) = mpsc::channel(20); @@ -413,7 +419,8 @@ mod tests { mod run_in_isolation { use iroh_base::SecretKey; use n0_future::StreamExt; - use testresult::TestResult; + use n0_snafu::{Error, Result, ResultExt}; + use snafu::whatever; use tracing_test::traced_test; use super::super::*; @@ -421,13 +428,13 @@ mod tests { #[tokio::test] #[traced_test] - async fn mdns_publish_resolve() -> TestResult { + async fn mdns_publish_resolve() -> Result { let (_, discovery_a) = make_discoverer()?; let (node_id_b, discovery_b) = make_discoverer()?; // make addr info for discoverer b let user_data: UserData = "foobar".parse()?; - let node_data = NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse()?])) + let node_data = NodeData::new(None, 
BTreeSet::from(["0.0.0.0:11111".parse().unwrap()])) .with_user_data(Some(user_data.clone())); println!("info {node_data:?}"); @@ -442,10 +449,12 @@ mod tests { // publish discovery_b's address discovery_b.publish(&node_data); let s1_res = tokio::time::timeout(Duration::from_secs(5), s1.next()) - .await? + .await + .context("timeout")? .unwrap()?; let s2_res = tokio::time::timeout(Duration::from_secs(5), s2.next()) - .await? + .await + .context("timeout")? .unwrap()?; assert_eq!(s1_res.node_info().data, node_data); assert_eq!(s2_res.node_info().data, node_data); @@ -455,13 +464,13 @@ mod tests { #[tokio::test] #[traced_test] - async fn mdns_subscribe() -> TestResult { + async fn mdns_subscribe() -> Result { let num_nodes = 5; let mut node_ids = BTreeSet::new(); let mut discoverers = vec![]; let (_, discovery) = make_discoverer()?; - let node_data = NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse()?])); + let node_data = NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse().unwrap()])); for i in 0..num_nodes { let (node_id, discovery) = make_discoverer()?; @@ -482,17 +491,18 @@ mod tests { got_ids.insert((item.node_id(), item.user_data())); } } else { - anyhow::bail!( + whatever!( "no more events, only got {} ids, expected {num_nodes}\n", got_ids.len() ); } } assert_eq!(got_ids, node_ids); - anyhow::Ok(()) + Ok::<_, Error>(()) }; - tokio::time::timeout(Duration::from_secs(5), test).await??; - Ok(()) + tokio::time::timeout(Duration::from_secs(5), test) + .await + .context("timeout")? 
} fn make_discoverer() -> Result<(PublicKey, MdnsDiscovery)> { diff --git a/iroh/src/discovery/pkarr.rs b/iroh/src/discovery/pkarr.rs index be3f8d958df..a014e62b10e 100644 --- a/iroh/src/discovery/pkarr.rs +++ b/iroh/src/discovery/pkarr.rs @@ -46,20 +46,24 @@ use std::sync::Arc; -use anyhow::{anyhow, bail, Result}; -use iroh_base::{NodeId, SecretKey}; +use iroh_base::{NodeId, RelayUrl, SecretKey}; use iroh_relay::node_info::NodeInfo; use n0_future::{ boxed::BoxStream, task::{self, AbortOnDropHandle}, time::{self, Duration, Instant}, }; -use pkarr::SignedPacket; +use pkarr::{ + errors::{PublicKeyError, SignedPacketVerifyError}, + SignedPacket, +}; +use snafu::{ResultExt, Snafu}; use tracing::{debug, error_span, warn, Instrument}; use url::Url; +use super::DiscoveryError; use crate::{ - discovery::{Discovery, DiscoveryItem, NodeData}, + discovery::{Discovery, DiscoveryItem, NodeData, ParsePacketSnafu, SignedPacketSnafu}, endpoint::force_staging_infra, watcher::{self, Disconnected, Watchable, Watcher as _}, Endpoint, @@ -68,6 +72,30 @@ use crate::{ #[cfg(feature = "discovery-pkarr-dht")] pub mod dht; +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum PkarrError { + #[snafu(display("Invalid public key"))] + PublicKey { source: PublicKeyError }, + #[snafu(display("Packet failed to verify"))] + Verify { source: SignedPacketVerifyError }, + #[snafu(display("Invalid relay URL"))] + InvalidRelayUrl { url: RelayUrl }, + #[snafu(display("Error sending http request"))] + HttpSend { source: reqwest::Error }, + #[snafu(display("Error resolving http request"))] + HttpRequest { status: reqwest::StatusCode }, + #[snafu(display("Http payload error"))] + HttpPayload { source: reqwest::Error }, +} + +impl From for DiscoveryError { + fn from(err: PkarrError) -> Self { + DiscoveryError::from_err("pkarr", err) + } +} + /// The production pkarr relay run by [number 0]. 
/// /// This server is both a pkarr relay server as well as a DNS resolver, see the [module @@ -256,20 +284,22 @@ impl PublisherService { tokio::select! { res = self.watcher.updated() => match res { Ok(_) => debug!("Publish node info to pkarr (info changed)"), - Err(Disconnected) => break, + Err(Disconnected { .. }) => break, }, _ = &mut republish => debug!("Publish node info to pkarr (interval elapsed)"), } } } - async fn publish_current(&self, info: NodeInfo) -> Result<()> { + async fn publish_current(&self, info: NodeInfo) -> Result<(), DiscoveryError> { debug!( data = ?info.data, pkarr_relay = %self.pkarr_client.pkarr_relay_url, "Publish node info to pkarr" ); - let signed_packet = info.to_pkarr_signed_packet(&self.secret_key, self.ttl)?; + let signed_packet = info + .to_pkarr_signed_packet(&self.secret_key, self.ttl) + .context(SignedPacketSnafu)?; self.pkarr_client.publish(&signed_packet).await?; Ok(()) } @@ -322,11 +352,16 @@ impl PkarrResolver { } impl Discovery for PkarrResolver { - fn resolve(&self, _ep: Endpoint, node_id: NodeId) -> Option>> { + fn resolve( + &self, + _ep: Endpoint, + node_id: NodeId, + ) -> Option>> { let pkarr_client = self.pkarr_client.clone(); let fut = async move { let signed_packet = pkarr_client.resolve(node_id).await?; - let info = NodeInfo::from_pkarr_signed_packet(&signed_packet)?; + let info = + NodeInfo::from_pkarr_signed_packet(&signed_packet).context(ParsePacketSnafu)?; let item = DiscoveryItem::new(info, "pkarr", None); Ok(item) }; @@ -354,35 +389,52 @@ impl PkarrRelayClient { } /// Resolves a [`SignedPacket`] for the given [`NodeId`]. 
- pub async fn resolve(&self, node_id: NodeId) -> anyhow::Result { + pub async fn resolve(&self, node_id: NodeId) -> Result { // We map the error to string, as in browsers the error is !Send - let public_key = pkarr::PublicKey::try_from(node_id.as_bytes()) - .map_err(|e| anyhow::anyhow!(e.to_string()))?; + let public_key = pkarr::PublicKey::try_from(node_id.as_bytes()).context(PublicKeySnafu)?; + let mut url = self.pkarr_relay_url.clone(); url.path_segments_mut() - .map_err(|_| anyhow!("Failed to resolve: Invalid relay URL"))? + .map_err(|_| { + InvalidRelayUrlSnafu { + url: self.pkarr_relay_url.clone(), + } + .build() + })? .push(&public_key.to_z32()); - let response = self.http_client.get(url).send().await?; + let response = self + .http_client + .get(url) + .send() + .await + .context(HttpSendSnafu)?; if !response.status().is_success() { - bail!(format!( - "Resolve request failed with status {}", - response.status() - )) + return Err(HttpRequestSnafu { + status: response.status(), + } + .build() + .into()); } - let payload = response.bytes().await?; + let payload = response.bytes().await.context(HttpPayloadSnafu)?; // We map the error to string, as in browsers the error is !Send - SignedPacket::from_relay_payload(&public_key, &payload) - .map_err(|e| anyhow::anyhow!(e.to_string())) + let packet = + SignedPacket::from_relay_payload(&public_key, &payload).context(VerifySnafu)?; + Ok(packet) } /// Publishes a [`SignedPacket`]. - pub async fn publish(&self, signed_packet: &SignedPacket) -> anyhow::Result<()> { + pub async fn publish(&self, signed_packet: &SignedPacket) -> Result<(), DiscoveryError> { let mut url = self.pkarr_relay_url.clone(); url.path_segments_mut() - .map_err(|_| anyhow!("Failed to publish: Invalid relay URL"))? + .map_err(|_| { + InvalidRelayUrlSnafu { + url: self.pkarr_relay_url.clone(), + } + .build() + })? 
.push(&signed_packet.public_key().to_z32()); let response = self @@ -390,18 +442,15 @@ impl PkarrRelayClient { .put(url) .body(signed_packet.to_relay_payload()) .send() - .await?; - - let status = response.status(); - if !status.is_success() { - let text = if let Ok(text) = response.text().await { - text - } else { - "(no details provided)".to_string() - }; - bail!(format!( - "Publish request failed with status {status}: {text}", - )) + .await + .context(HttpSendSnafu)?; + + if !response.status().is_success() { + return Err(HttpRequestSnafu { + status: response.status(), + } + .build() + .into()); } Ok(()) diff --git a/iroh/src/discovery/pkarr/dht.rs b/iroh/src/discovery/pkarr/dht.rs index 3164892566e..2c3d27d075b 100644 --- a/iroh/src/discovery/pkarr/dht.rs +++ b/iroh/src/discovery/pkarr/dht.rs @@ -7,7 +7,6 @@ //! [pkarr module]: super use std::sync::{Arc, Mutex}; -use anyhow::Result; use iroh_base::{NodeId, SecretKey}; use n0_future::{ boxed::BoxStream, @@ -21,7 +20,7 @@ use url::Url; use crate::{ discovery::{ pkarr::{DEFAULT_PKARR_TTL, N0_DNS_PKARR_RELAY_PROD}, - Discovery, DiscoveryItem, NodeData, + Discovery, DiscoveryError, DiscoveryItem, NodeData, }, node_info::NodeInfo, Endpoint, @@ -81,7 +80,10 @@ struct Inner { } impl Inner { - async fn resolve_pkarr(&self, key: pkarr::PublicKey) -> Option> { + async fn resolve_pkarr( + &self, + key: pkarr::PublicKey, + ) -> Option> { tracing::info!( "resolving {} from relay and DHT {:?}", key.to_z32(), @@ -198,11 +200,13 @@ impl Builder { } /// Builds the discovery mechanism. 
- pub fn build(self) -> Result { - anyhow::ensure!( - self.dht || self.pkarr_relay.is_some(), - "at least one of DHT or relay must be enabled" - ); + pub fn build(self) -> Result { + if !(self.dht || self.pkarr_relay.is_some()) { + return Err(DiscoveryError::from_err( + "pkarr", + std::io::Error::other("at least one of DHT or relay must be enabled"), + )); + } let pkarr = match self.client { Some(client) => client, None => { @@ -212,9 +216,13 @@ impl Builder { builder.dht(|x| x); } if let Some(url) = &self.pkarr_relay { - builder.relays(&[url.clone()])?; + builder + .relays(&[url.clone()]) + .map_err(|e| DiscoveryError::from_err("pkarr", e))?; } - builder.build()? + builder + .build() + .map_err(|e| DiscoveryError::from_err("pkarr", e))? } }; let ttl = self.ttl.unwrap_or(DEFAULT_PKARR_TTL); @@ -295,7 +303,7 @@ impl Discovery for DhtDiscovery { &self, _endpoint: Endpoint, node_id: NodeId, - ) -> Option>> { + ) -> Option>> { let pkarr_public_key = pkarr::PublicKey::try_from(node_id.as_bytes()).expect("valid public key"); tracing::info!("resolving {} as {}", node_id, pkarr_public_key.to_z32()); @@ -314,7 +322,7 @@ mod tests { use std::collections::BTreeSet; use iroh_base::RelayUrl; - use testresult::TestResult; + use n0_snafu::{Result, ResultExt}; use tracing_test::traced_test; use super::*; @@ -322,20 +330,21 @@ mod tests { #[tokio::test] #[ignore = "flaky"] #[traced_test] - async fn dht_discovery_smoke() -> TestResult { + async fn dht_discovery_smoke() -> Result { let ep = crate::Endpoint::builder().bind().await?; let secret = ep.secret_key().clone(); - let testnet = pkarr::mainline::Testnet::new_async(3).await?; + let testnet = pkarr::mainline::Testnet::new_async(3).await.e()?; let client = pkarr::Client::builder() .dht(|builder| builder.bootstrap(&testnet.bootstrap)) - .build()?; + .build() + .e()?; let discovery = DhtDiscovery::builder() .secret_key(secret.clone()) .initial_publish_delay(Duration::ZERO) .client(client) .build()?; - let relay_url: RelayUrl = 
Url::parse("https://example.com")?.into(); + let relay_url: RelayUrl = Url::parse("https://example.com").e()?.into(); let data = NodeData::default().with_relay_url(Some(relay_url.clone())); discovery.publish(&data); diff --git a/iroh/src/discovery/static_provider.rs b/iroh/src/discovery/static_provider.rs index 047cb7cc630..8b6392bd473 100644 --- a/iroh/src/discovery/static_provider.rs +++ b/iroh/src/discovery/static_provider.rs @@ -41,7 +41,7 @@ use super::{Discovery, DiscoveryItem, NodeData, NodeInfo}; /// use iroh_base::SecretKey; /// /// # #[tokio::main] -/// # async fn main() -> anyhow::Result<()> { +/// # async fn main() -> n0_snafu::Result<()> { /// // Create the discovery service and endpoint. /// let discovery = StaticProvider::new(); /// @@ -108,7 +108,7 @@ impl StaticProvider { /// # Vec::new() /// # } /// # #[tokio::main] - /// # async fn main() -> anyhow::Result<()> { + /// # async fn main() -> n0_snafu::Result<()> { /// // get addrs from somewhere /// let addrs = get_addrs(); /// @@ -193,7 +193,7 @@ impl Discovery for StaticProvider { &self, _endpoint: crate::Endpoint, node_id: NodeId, - ) -> Option>> { + ) -> Option>> { let guard = self.nodes.read().expect("poisoned"); let info = guard.get(&node_id); match info { @@ -217,15 +217,14 @@ impl Discovery for StaticProvider { #[cfg(test)] mod tests { - use anyhow::Context; use iroh_base::{NodeAddr, SecretKey}; - use testresult::TestResult; + use n0_snafu::{Result, ResultExt}; use super::*; use crate::Endpoint; #[tokio::test] - async fn test_basic() -> TestResult { + async fn test_basic() -> Result { let discovery = StaticProvider::new(); let _ep = Endpoint::builder() diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index 8df217d86ea..241666787e0 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -21,12 +21,13 @@ use std::{ task::Poll, }; -use anyhow::{bail, Context, Result}; use ed25519_dalek::{pkcs8::DecodePublicKey, VerifyingKey}; use iroh_base::{NodeAddr, NodeId, RelayUrl, 
SecretKey}; use iroh_relay::RelayMap; use n0_future::{time::Duration, Stream}; +use nested_enum_utils::common_fields; use pin_project::pin_project; +use snafu::{ensure, ResultExt, Snafu}; use tracing::{debug, instrument, trace, warn}; use url::Url; @@ -36,10 +37,10 @@ use crate::discovery::pkarr::PkarrResolver; use crate::{discovery::dns::DnsDiscovery, dns::DnsResolver}; use crate::{ discovery::{ - pkarr::PkarrPublisher, ConcurrentDiscovery, Discovery, DiscoveryItem, DiscoverySubscribers, - DiscoveryTask, Lagged, UserData, + pkarr::PkarrPublisher, ConcurrentDiscovery, Discovery, DiscoveryError, DiscoveryItem, + DiscoverySubscribers, DiscoveryTask, Lagged, UserData, }, - magicsock::{self, Handle, NodeIdMappedAddr}, + magicsock::{self, Handle, NodeIdMappedAddr, OwnAddressSnafu}, metrics::EndpointMetrics, net_report::Report, tls, @@ -68,7 +69,8 @@ pub use quinn_proto::{ use self::rtt_actor::RttMessage; pub use super::magicsock::{ - ConnectionType, ControlMsg, DirectAddr, DirectAddrInfo, DirectAddrType, RemoteInfo, Source, + AddNodeAddrError, ConnectionType, ControlMsg, DirectAddr, DirectAddrInfo, DirectAddrType, + RemoteInfo, Source, }; /// The delay to fall back to discovery when direct addresses fail. @@ -181,7 +183,7 @@ impl Builder { // # The final constructor that everyone needs. /// Binds the magic endpoint. 
- pub async fn bind(self) -> Result { + pub async fn bind(self) -> Result { let relay_map = self.relay_mode.relay_map(); let secret_key = self .secret_key @@ -609,6 +611,79 @@ pub struct Endpoint { static_config: Arc, } +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ConnectWithOptsError { + #[snafu(transparent)] + AddNodeAddr { source: AddNodeAddrError }, + #[snafu(display("Connecting to ourself is not supported"))] + SelfConnect {}, + #[snafu(display("No addressing information available"))] + NoAddress { source: GetMappingAddressError }, + #[snafu(display("Unable to connect to remote"))] + Quinn { source: quinn::ConnectError }, +} + +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ConnectError { + #[snafu(transparent)] + Connect { + #[snafu(source(from(ConnectWithOptsError, Box::new)))] + source: Box, + }, + #[snafu(transparent)] + Connection { + #[snafu(source(from(ConnectionError, Box::new)))] + source: Box, + }, +} + +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum BindError { + #[snafu(transparent)] + MagicSpawn { + source: magicsock::CreateHandleError, + }, +} + +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[snafu(module)] +#[non_exhaustive] +pub enum GetMappingAddressError { + #[snafu(display("Discovery service required due to missing addressing information"))] + DiscoveryStart { source: DiscoveryError }, + #[snafu(display("Discovery service failed"))] + Discover { source: DiscoveryError }, + #[snafu(display("No addressing information found"))] + NoAddress {}, +} + impl 
Endpoint { // The ordering of public methods is reflected directly in the documentation. This is // roughly ordered by what is most commonly needed by users, but grouped in similar @@ -626,7 +701,10 @@ impl Endpoint { /// This is for internal use, the public interface is the [`Builder`] obtained from /// [Self::builder]. See the methods on the builder for documentation of the parameters. #[instrument("ep", skip_all, fields(me = %static_config.tls_config.secret_key.public().fmt_short()))] - async fn bind(static_config: StaticConfig, msock_opts: magicsock::Options) -> Result { + async fn bind( + static_config: StaticConfig, + msock_opts: magicsock::Options, + ) -> Result { let msock = magicsock::MagicSock::spawn(msock_opts).await?; trace!("created magicsock"); debug!(version = env!("CARGO_PKG_VERSION"), "iroh Endpoint created"); @@ -672,15 +750,18 @@ impl Endpoint { /// The `alpn`, or application-level protocol identifier, is also required. The remote /// endpoint must support this `alpn`, otherwise the connection attempt will fail with /// an error. - pub async fn connect(&self, node_addr: impl Into, alpn: &[u8]) -> Result { + pub async fn connect( + &self, + node_addr: impl Into, + alpn: &[u8], + ) -> Result { let node_addr = node_addr.into(); let remote = node_addr.node_id; let connecting = self .connect_with_opts(node_addr, alpn, Default::default()) .await?; - let conn = connecting - .await - .context("failed connecting to remote endpoint")?; + let conn = connecting.await?; + debug!( me = %self.node_id().fmt_short(), remote = %remote.fmt_short(), @@ -714,17 +795,12 @@ impl Endpoint { node_addr: impl Into, alpn: &[u8], options: ConnectOptions, - ) -> Result { + ) -> Result { let node_addr: NodeAddr = node_addr.into(); tracing::Span::current().record("remote", node_addr.node_id.fmt_short()); // Connecting to ourselves is not supported. 
- if node_addr.node_id == self.node_id() { - bail!( - "Connecting to ourself is not supported ({} is the node id of this node)", - node_addr.node_id.fmt_short() - ); - } + ensure!(node_addr.node_id != self.node_id(), SelfConnectSnafu); if !node_addr.is_empty() { self.add_node_addr(node_addr.clone())?; @@ -740,12 +816,7 @@ impl Endpoint { let (mapped_addr, _discovery_drop_guard) = self .get_mapping_addr_and_maybe_start_discovery(node_addr) .await - .with_context(|| { - format!( - "No addressing information for NodeId({}), unable to connect", - node_id.fmt_short() - ) - })?; + .context(NoAddressSnafu)?; let transport_config = options .transport_config @@ -773,11 +844,15 @@ impl Endpoint { }; let server_name = &tls::name::encode(node_id); - let connect = self.msock.endpoint().connect_with( - client_config, - mapped_addr.private_socket_addr(), - server_name, - )?; + let connect = self + .msock + .endpoint() + .connect_with( + client_config, + mapped_addr.private_socket_addr(), + server_name, + ) + .context(QuinnSnafu)?; Ok(Connecting { inner: connect, @@ -825,7 +900,7 @@ impl Endpoint { /// if the direct addresses are a subset of ours. /// /// [`StaticProvider`]: crate::discovery::static_provider::StaticProvider - pub fn add_node_addr(&self, node_addr: NodeAddr) -> Result<()> { + pub fn add_node_addr(&self, node_addr: NodeAddr) -> Result<(), AddNodeAddrError> { self.add_node_addr_inner(node_addr, magicsock::Source::App) } @@ -852,7 +927,7 @@ impl Endpoint { &self, node_addr: NodeAddr, source: &'static str, - ) -> Result<()> { + ) -> Result<(), AddNodeAddrError> { self.add_node_addr_inner( node_addr, magicsock::Source::NamedApp { @@ -861,14 +936,13 @@ impl Endpoint { ) } - fn add_node_addr_inner(&self, node_addr: NodeAddr, source: magicsock::Source) -> Result<()> { + fn add_node_addr_inner( + &self, + node_addr: NodeAddr, + source: magicsock::Source, + ) -> Result<(), AddNodeAddrError> { // Connecting to ourselves is not supported. 
- if node_addr.node_id == self.node_id() { - bail!( - "Adding our own address is not supported ({} is the node id of this node)", - node_addr.node_id.fmt_short() - ); - } + snafu::ensure!(node_addr.node_id != self.node_id(), OwnAddressSnafu); self.msock.add_node_addr(node_addr, source) } @@ -895,7 +969,7 @@ impl Endpoint { /// Use [`Watcher::initialized`] to wait for a [`NodeAddr`] that is ready to be connected to: /// /// ```no_run - /// # async fn wrapper() -> testresult::TestResult { + /// # async fn wrapper() -> n0_snafu::Result { /// use iroh::{watcher::Watcher, Endpoint}; /// /// let endpoint = Endpoint::builder() @@ -1147,8 +1221,8 @@ impl Endpoint { /// /// # Errors /// - /// Will error if we do not have any address information for the given `node_id`. - pub fn conn_type(&self, node_id: NodeId) -> Result> { + /// Will return `None` if we do not have any address information for the given `node_id`. + pub fn conn_type(&self, node_id: NodeId) -> Option> { self.msock.conn_type(node_id) } @@ -1176,7 +1250,7 @@ impl Endpoint { /// ```rust /// # use std::collections::BTreeMap; /// # use iroh::endpoint::Endpoint; - /// # async fn wrapper() -> testresult::TestResult { + /// # async fn wrapper() -> n0_snafu::Result { /// let endpoint = Endpoint::builder().bind().await?; /// assert_eq!(endpoint.metrics().magicsock.recv_datagrams.get(), 0); /// # Ok(()) @@ -1194,7 +1268,7 @@ impl Endpoint { /// # use std::collections::BTreeMap; /// # use iroh_metrics::{Metric, MetricsGroup, MetricValue, MetricsGroupSet}; /// # use iroh::endpoint::Endpoint; - /// # async fn wrapper() -> testresult::TestResult { + /// # async fn wrapper() -> n0_snafu::Result { /// let endpoint = Endpoint::builder().bind().await?; /// let metrics: BTreeMap = endpoint /// .metrics() @@ -1217,7 +1291,7 @@ impl Endpoint { /// ```rust /// # use iroh_metrics::{Registry, MetricsSource}; /// # use iroh::endpoint::Endpoint; - /// # async fn wrapper() -> testresult::TestResult { + /// # async fn wrapper() -> 
n0_snafu::Result { /// let endpoint = Endpoint::builder().bind().await?; /// let mut registry = Registry::default(); /// registry.register_all(endpoint.metrics()); @@ -1240,7 +1314,8 @@ impl Endpoint { /// # use std::{sync::{Arc, RwLock}, time::Duration}; /// # use iroh_metrics::{Registry, MetricsSource}; /// # use iroh::endpoint::Endpoint; - /// # async fn wrapper() -> testresult::TestResult { + /// # use n0_snafu::ResultExt; + /// # async fn wrapper() -> n0_snafu::Result { /// // Create a registry, wrapped in a read-write lock so that we can register and serve /// // the metrics independently. /// let registry = Arc::new(RwLock::new(Registry::default())); @@ -1260,9 +1335,11 @@ impl Endpoint { /// // Wait for the metrics server to bind, then fetch the metrics via HTTP. /// tokio::time::sleep(Duration::from_millis(500)); /// let res = reqwest::get("http://localhost:9100/metrics") - /// .await? + /// .await + /// .context("get")? /// .text() - /// .await?; + /// .await + /// .context("text")?; /// /// assert!(res.contains(r#"TYPE magicsock_recv_datagrams counter"#)); /// assert!(res.contains(r#"magicsock_recv_datagrams_total 0"#)); @@ -1378,7 +1455,7 @@ impl Endpoint { async fn get_mapping_addr_and_maybe_start_discovery( &self, node_addr: NodeAddr, - ) -> Result<(NodeIdMappedAddr, Option)> { + ) -> Result<(NodeIdMappedAddr, Option), GetMappingAddressError> { let node_id = node_addr.node_id; // Only return a mapped addr if we have some way of dialing this node, in other @@ -1409,16 +1486,16 @@ impl Endpoint { // So, we start a discovery task and wait for the first result to arrive, and // only then continue, because otherwise we wouldn't have any // path to the remote endpoint. 
- let mut discovery = DiscoveryTask::start(self.clone(), node_id) - .context("Discovery service required due to missing addressing information")?; + let res = DiscoveryTask::start(self.clone(), node_id); + let mut discovery = res.context(get_mapping_address_error::DiscoveryStartSnafu)?; discovery .first_arrived() .await - .context("Discovery service failed")?; + .context(get_mapping_address_error::DiscoverSnafu)?; if let Some(addr) = self.msock.get_mapping_addr(node_id) { Ok((addr, Some(discovery))) } else { - bail!("Discovery did not find addressing information"); + Err(get_mapping_address_error::NoAddressSnafu.build()) } } } @@ -1653,6 +1730,23 @@ pub struct Connecting { _discovery_drop_guard: Option, } +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum AlpnError { + #[snafu(transparent)] + ConnectionError { source: ConnectionError }, + #[snafu(display("No ALPN available"))] + Unavailable {}, + #[snafu(display("Unknown handshake type"))] + UnknownHandshake {}, +} + impl Connecting { /// Converts this [`Connecting`] into a 0-RTT or 0.5-RTT connection at the cost of weakened /// security. @@ -1735,16 +1829,14 @@ impl Connecting { } /// Extracts the ALPN protocol from the peer's handshake data. - // Note, we could totally provide this method to be on a Connection as well. But we'd - // need to wrap Connection too. 
- pub async fn alpn(&mut self) -> Result> { + pub async fn alpn(&mut self) -> Result, AlpnError> { let data = self.handshake_data().await?; match data.downcast::() { Ok(data) => match data.protocol { Some(protocol) => Ok(protocol), - None => bail!("no ALPN protocol available"), + None => Err(UnavailableSnafu.build()), }, - Err(_) => bail!("unknown handshake type"), + Err(_) => Err(UnknownHandshakeSnafu.build()), } } } @@ -1813,6 +1905,13 @@ pub struct Connection { tls_auth: tls::Authentication, } +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[snafu(display("Protocol error: no remote id available"))] +pub struct RemoteNodeIdError { + backtrace: Option, +} + impl Connection { /// Initiates a new outgoing unidirectional stream. /// @@ -2028,31 +2127,41 @@ impl Connection { /// /// [`PublicKey`]: iroh_base::PublicKey // TODO: Would be nice if this could be infallible. - pub fn remote_node_id(&self) -> Result { + pub fn remote_node_id(&self) -> Result { let data = self.peer_identity(); match data { - None => bail!("no peer certificate found"), + None => { + warn!("no peer certificate found"); + Err(RemoteNodeIdSnafu.build()) + } Some(data) => match data.downcast::>() { Ok(certs) => { if certs.len() != 1 { - bail!( + warn!( "expected a single peer certificate, but {} found", certs.len() ); + return Err(RemoteNodeIdSnafu.build()); } match self.tls_auth { tls::Authentication::X509 => { - let cert = tls::certificate::parse(&certs[0])?; + let cert = tls::certificate::parse(&certs[0]) + .map_err(|_| RemoteNodeIdSnafu.build())?; Ok(cert.peer_id()) } tls::Authentication::RawPublicKey => { - let peer_id = VerifyingKey::from_public_key_der(&certs[0])?.into(); + let peer_id = VerifyingKey::from_public_key_der(&certs[0]) + .map_err(|_| RemoteNodeIdSnafu.build())? 
+ .into(); Ok(peer_id) } } } - Err(_) => bail!("invalid peer certificate"), + Err(err) => { + warn!("invalid peer certificate: {:?}", err); + Err(RemoteNodeIdSnafu.build()) + } }, } } @@ -2119,7 +2228,7 @@ fn try_send_rtt_msg(conn: &Connection, magic_ep: &Endpoint, remote_node_id: Opti warn!(?conn, "failed to get remote node id"); return; }; - let Ok(conn_type_changes) = magic_ep.conn_type(node_id) else { + let Some(conn_type_changes) = magic_ep.conn_type(node_id) else { warn!(?conn, "failed to create conn_type stream"); return; }; @@ -2231,25 +2340,34 @@ fn is_cgi() -> bool { // https://github.com/n0-computer/iroh/issues/1183 #[cfg(test)] mod tests { + use std::{ + net::SocketAddr, + time::{Duration, Instant}, + }; - use std::time::Instant; - - use iroh_metrics::MetricsSource; + use iroh_base::{NodeAddr, NodeId, SecretKey}; use iroh_relay::http::Protocol; use n0_future::{task::AbortOnDropHandle, StreamExt}; + use n0_snafu::{Error, Result, ResultExt}; + use quinn::ConnectionError; use rand::SeedableRng; - use testresult::TestResult; use tracing::{error_span, info, info_span, Instrument}; use tracing_test::traced_test; - use super::*; - use crate::test_utils::{run_relay_server, run_relay_server_with}; + use super::Endpoint; + use crate::{ + endpoint::{ConnectOptions, Connection, ConnectionType, RemoteInfo}, + test_utils::{run_relay_server, run_relay_server_with}, + tls, + watcher::Watcher, + RelayMode, + }; const TEST_ALPN: &[u8] = b"n0/iroh/test"; #[tokio::test] #[traced_test] - async fn test_connect_self() { + async fn test_connect_self() -> Result { let ep = Endpoint::builder() .alpns(vec![TEST_ALPN.to_vec()]) .bind() @@ -2265,12 +2383,13 @@ mod tests { assert!(res.is_err()); let err = res.err().unwrap(); assert!(err.to_string().starts_with("Adding our own address")); + Ok(()) } #[tokio::test] #[traced_test] - async fn endpoint_connect_close() { - let (relay_map, relay_url, _guard) = run_relay_server().await.unwrap(); + async fn endpoint_connect_close() -> Result 
{ + let (relay_map, relay_url, _guard) = run_relay_server().await?; let server_secret_key = SecretKey::generate(rand::thread_rng()); let server_peer_id = server_secret_key.public(); @@ -2284,14 +2403,13 @@ mod tests { .relay_mode(RelayMode::Custom(relay_map)) .insecure_skip_relay_cert_verify(true) .bind() - .await - .unwrap(); + .await?; info!("accepting connection"); - let incoming = ep.accept().await.unwrap(); - let conn = incoming.await.unwrap(); - let mut stream = conn.accept_uni().await.unwrap(); + let incoming = ep.accept().await.e()?; + let conn = incoming.await.e()?; + let mut stream = conn.accept_uni().await.e()?; let mut buf = [0u8; 5]; - stream.read_exact(&mut buf).await.unwrap(); + stream.read_exact(&mut buf).await.e()?; info!("Accepted 1 stream, received {buf:?}. Closing now."); // close the connection conn.close(7u8.into(), b"bye"); @@ -2307,6 +2425,7 @@ mod tests { )) ); info!("server test completed"); + Ok::<_, Error>(()) } .instrument(info_span!("test-server")), ) @@ -2319,18 +2438,17 @@ mod tests { .relay_mode(RelayMode::Custom(relay_map)) .insecure_skip_relay_cert_verify(true) .bind() - .await - .unwrap(); + .await?; info!("client connecting"); let node_addr = NodeAddr::new(server_peer_id).with_relay_url(relay_url); - let conn = ep.connect(node_addr, TEST_ALPN).await.unwrap(); - let mut stream = conn.open_uni().await.unwrap(); + let conn = ep.connect(node_addr, TEST_ALPN).await?; + let mut stream = conn.open_uni().await.e()?; // First write is accepted by server. We need this bit of synchronisation // because if the server closes after simply accepting the connection we can // not be sure our .open_uni() call would succeed as it may already receive // the error. - stream.write_all(b"hello").await.unwrap(); + stream.write_all(b"hello").await.e()?; info!("waiting for closed"); // Remote now closes the connection, we should see an error sometime soon. 
@@ -2346,6 +2464,7 @@ mod tests { let res = conn.open_uni().await; assert_eq!(res.unwrap_err(), expected_err); info!("client test completed"); + Ok::<_, Error>(()) } .instrument(info_span!("test-client")), ); @@ -2355,19 +2474,23 @@ mod tests { n0_future::future::zip(server, client), ) .await - .expect("timeout"); - server.unwrap(); - client.unwrap(); + .e()?; + server.e()??; + client.e()??; + Ok(()) } /// Test that peers are properly restored #[tokio::test] #[traced_test] - async fn restore_peers() { + async fn restore_peers() -> Result { let secret_key = SecretKey::generate(rand::thread_rng()); /// Create an endpoint for the test. - async fn new_endpoint(secret_key: SecretKey, nodes: Option>) -> Endpoint { + async fn new_endpoint( + secret_key: SecretKey, + nodes: Option>, + ) -> Result { let mut transport_config = quinn::TransportConfig::default(); transport_config.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap())); @@ -2377,11 +2500,7 @@ mod tests { if let Some(nodes) = nodes { builder = builder.known_nodes(nodes); } - builder - .alpns(vec![TEST_ALPN.to_vec()]) - .bind() - .await - .unwrap() + Ok(builder.alpns(vec![TEST_ALPN.to_vec()]).bind().await?) 
} // create the peer that will be added to the peer map @@ -2393,9 +2512,9 @@ mod tests { info!("setting up first endpoint"); // first time, create a magic endpoint without peers but a peers file and add addressing // information for a peer - let endpoint = new_endpoint(secret_key.clone(), None).await; + let endpoint = new_endpoint(secret_key.clone(), None).await?; assert_eq!(endpoint.remote_info_iter().count(), 0); - endpoint.add_node_addr(node_addr.clone()).unwrap(); + endpoint.add_node_addr(node_addr.clone())?; // Grab the current addrs let node_addrs: Vec = endpoint.remote_info_iter().map(Into::into).collect(); @@ -2408,15 +2527,16 @@ mod tests { info!("restarting endpoint"); // now restart it and check the addressing info of the peer - let endpoint = new_endpoint(secret_key, Some(node_addrs)).await; - let RemoteInfo { mut addrs, .. } = endpoint.remote_info(peer_id).unwrap(); + let endpoint = new_endpoint(secret_key, Some(node_addrs)).await?; + let RemoteInfo { mut addrs, .. } = endpoint.remote_info(peer_id).e()?; let conn_addr = addrs.pop().unwrap().addr; assert_eq!(conn_addr, direct_addr); + Ok(()) } #[tokio::test] #[traced_test] - async fn endpoint_relay_connect_loop() { + async fn endpoint_relay_connect_loop() -> Result { let start = Instant::now(); let n_clients = 5; let n_chunks_per_client = 2; @@ -2437,29 +2557,29 @@ mod tests { .alpns(vec![TEST_ALPN.to_vec()]) .relay_mode(RelayMode::Custom(relay_map)) .bind() - .await - .unwrap(); + .await?; let eps = ep.bound_sockets(); info!(me = %ep.node_id().fmt_short(), ipv4=%eps.0, ipv6=?eps.1, "server listening on"); for i in 0..n_clients { let round_start = Instant::now(); info!("[server] round {i}"); - let incoming = ep.accept().await.unwrap(); - let conn = incoming.await.unwrap(); - let node_id = conn.remote_node_id().unwrap(); + let incoming = ep.accept().await.e()?; + let conn = incoming.await.e()?; + let node_id = conn.remote_node_id()?; info!(%i, peer = %node_id.fmt_short(), "accepted connection"); - let 
(mut send, mut recv) = conn.accept_bi().await.unwrap(); + let (mut send, mut recv) = conn.accept_bi().await.e()?; let mut buf = vec![0u8; chunk_size]; for _i in 0..n_chunks_per_client { - recv.read_exact(&mut buf).await.unwrap(); - send.write_all(&buf).await.unwrap(); + recv.read_exact(&mut buf).await.e()?; + send.write_all(&buf).await.e()?; } - send.finish().unwrap(); - send.stopped().await.unwrap(); - recv.read_to_end(0).await.unwrap(); + send.finish().e()?; + send.stopped().await.e()?; + recv.read_to_end(0).await.e()?; info!(%i, peer = %node_id.fmt_short(), "finished"); info!("[server] round {i} done in {:?}", round_start.elapsed()); } + Ok::<_, Error>(()) } .instrument(error_span!("server")), ) @@ -2479,35 +2599,35 @@ mod tests { .relay_mode(RelayMode::Custom(relay_map)) .secret_key(client_secret_key) .bind() - .await - .unwrap(); + .await?; let eps = ep.bound_sockets(); info!(me = %ep.node_id().fmt_short(), ipv4=%eps.0, ipv6=?eps.1, "client bound"); let node_addr = NodeAddr::new(server_node_id).with_relay_url(relay_url); info!(to = ?node_addr, "client connecting"); - let conn = ep.connect(node_addr, TEST_ALPN).await.unwrap(); + let conn = ep.connect(node_addr, TEST_ALPN).await.e()?; info!("client connected"); - let (mut send, mut recv) = conn.open_bi().await.unwrap(); + let (mut send, mut recv) = conn.open_bi().await.e()?; for i in 0..n_chunks_per_client { let mut buf = vec![i; chunk_size]; - send.write_all(&buf).await.unwrap(); - recv.read_exact(&mut buf).await.unwrap(); + send.write_all(&buf).await.e()?; + recv.read_exact(&mut buf).await.e()?; assert_eq!(buf, vec![i; chunk_size]); } - send.finish().unwrap(); - send.stopped().await.unwrap(); - recv.read_to_end(0).await.unwrap(); + send.finish().e()?; + send.stopped().await.e()?; + recv.read_to_end(0).await.e()?; info!("client finished"); ep.close().await; info!("client closed"); + Ok::<_, Error>(()) } .instrument(error_span!("client", %i)) - .await; + .await?; info!("[client] round {i} done in {:?}", 
round_start.elapsed()); } - server.await.unwrap(); + server.await.e()??; // We appear to have seen this being very slow at times. So ensure we fail if this // test is too slow. We're only making two connections transferring very little @@ -2515,11 +2635,13 @@ mod tests { if start.elapsed() > Duration::from_secs(15) { panic!("Test too slow, something went wrong"); } + + Ok(()) } #[tokio::test] #[traced_test] - async fn endpoint_send_relay_websockets() -> testresult::TestResult { + async fn endpoint_send_relay_websockets() -> Result { let (relay_map, _relay_url, _guard) = run_relay_server().await?; let client = Endpoint::builder() .relay_conn_protocol(Protocol::Websocket) @@ -2539,28 +2661,28 @@ mod tests { let server = server.clone(); async move { let Some(conn) = server.accept().await else { - bail!("Expected an incoming connection"); + snafu::whatever!("Expected an incoming connection"); }; - let conn = conn.await?; - let (mut send, mut recv) = conn.accept_bi().await?; - let data = recv.read_to_end(1000).await?; - send.write_all(&data).await?; - send.finish()?; + let conn = conn.await.e()?; + let (mut send, mut recv) = conn.accept_bi().await.e()?; + let data = recv.read_to_end(1000).await.e()?; + send.write_all(&data).await.e()?; + send.finish().e()?; conn.closed().await; - Ok(()) + Ok::<_, Error>(()) } }); let addr = server.node_addr().initialized().await?; let conn = client.connect(addr, TEST_ALPN).await?; - let (mut send, mut recv) = conn.open_bi().await?; - send.write_all(b"Hello, world!").await?; - send.finish()?; - let data = recv.read_to_end(1000).await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; + send.write_all(b"Hello, world!").await.e()?; + send.finish().e()?; + let data = recv.read_to_end(1000).await.e()?; conn.close(0u32.into(), b"bye!"); - task.await??; + task.await.e()??; client.close().await; server.close().await; @@ -2572,17 +2694,17 @@ mod tests { #[tokio::test] #[traced_test] - async fn endpoint_bidi_send_recv_x509() { + async fn 
endpoint_bidi_send_recv_x509() -> Result { endpoint_bidi_send_recv(tls::Authentication::X509).await } #[tokio::test] #[traced_test] - async fn endpoint_bidi_send_recv_raw_public_key() { + async fn endpoint_bidi_send_recv_raw_public_key() -> Result { endpoint_bidi_send_recv(tls::Authentication::RawPublicKey).await } - async fn endpoint_bidi_send_recv(auth: tls::Authentication) { + async fn endpoint_bidi_send_recv(auth: tls::Authentication) -> Result { let ep1 = Endpoint::builder() .alpns(vec![TEST_ALPN.to_vec()]) .relay_mode(RelayMode::Disabled); @@ -2591,7 +2713,7 @@ mod tests { tls::Authentication::X509 => ep1.tls_x509(), tls::Authentication::RawPublicKey => ep1.tls_raw_public_keys(), }; - let ep1 = ep1.bind().await.unwrap(); + let ep1 = ep1.bind().await?; let ep2 = Endpoint::builder() .alpns(vec![TEST_ALPN.to_vec()]) .relay_mode(RelayMode::Disabled); @@ -2600,47 +2722,49 @@ mod tests { tls::Authentication::X509 => ep2.tls_x509(), tls::Authentication::RawPublicKey => ep2.tls_raw_public_keys(), }; - let ep2 = ep2.bind().await.unwrap(); + let ep2 = ep2.bind().await?; - let ep1_nodeaddr = ep1.node_addr().initialized().await.unwrap(); - let ep2_nodeaddr = ep2.node_addr().initialized().await.unwrap(); - ep1.add_node_addr(ep2_nodeaddr.clone()).unwrap(); - ep2.add_node_addr(ep1_nodeaddr.clone()).unwrap(); + let ep1_nodeaddr = ep1.node_addr().initialized().await?; + let ep2_nodeaddr = ep2.node_addr().initialized().await?; + ep1.add_node_addr(ep2_nodeaddr.clone())?; + ep2.add_node_addr(ep1_nodeaddr.clone())?; let ep1_nodeid = ep1.node_id(); let ep2_nodeid = ep2.node_id(); eprintln!("node id 1 {ep1_nodeid}"); eprintln!("node id 2 {ep2_nodeid}"); - async fn connect_hello(ep: Endpoint, dst: NodeAddr) { - let conn = ep.connect(dst, TEST_ALPN).await.unwrap(); - let (mut send, mut recv) = conn.open_bi().await.unwrap(); + async fn connect_hello(ep: Endpoint, dst: NodeAddr) -> Result { + let conn = ep.connect(dst, TEST_ALPN).await?; + let (mut send, mut recv) = 
conn.open_bi().await.e()?; info!("sending hello"); - send.write_all(b"hello").await.unwrap(); - send.finish().unwrap(); + send.write_all(b"hello").await.e()?; + send.finish().e()?; info!("receiving world"); - let m = recv.read_to_end(100).await.unwrap(); + let m = recv.read_to_end(100).await.e()?; assert_eq!(m, b"world"); conn.close(1u8.into(), b"done"); + Ok(()) } - async fn accept_world(ep: Endpoint, src: NodeId) { - let incoming = ep.accept().await.unwrap(); - let mut iconn = incoming.accept().unwrap(); - let alpn = iconn.alpn().await.unwrap(); - let conn = iconn.await.unwrap(); - let node_id = conn.remote_node_id().unwrap(); + async fn accept_world(ep: Endpoint, src: NodeId) -> Result { + let incoming = ep.accept().await.e()?; + let mut iconn = incoming.accept().e()?; + let alpn = iconn.alpn().await?; + let conn = iconn.await.e()?; + let node_id = conn.remote_node_id()?; assert_eq!(node_id, src); assert_eq!(alpn, TEST_ALPN); - let (mut send, mut recv) = conn.accept_bi().await.unwrap(); + let (mut send, mut recv) = conn.accept_bi().await.e()?; info!("receiving hello"); - let m = recv.read_to_end(100).await.unwrap(); + let m = recv.read_to_end(100).await.e()?; assert_eq!(m, b"hello"); info!("sending hello"); - send.write_all(b"world").await.unwrap(); - send.finish().unwrap(); + send.write_all(b"world").await.e()?; + send.finish().e()?; match conn.closed().await { ConnectionError::ApplicationClosed(closed) => { assert_eq!(closed.error_code, 1u8.into()); + Ok(()) } _ => panic!("wrong close error"), } @@ -2671,15 +2795,17 @@ mod tests { ), )); - p1_accept.await.unwrap(); - p2_accept.await.unwrap(); - p1_connect.await.unwrap(); - p2_connect.await.unwrap(); + p1_accept.await.e()??; + p2_accept.await.e()??; + p1_connect.await.e()??; + p2_connect.await.e()??; + + Ok(()) } #[tokio::test] #[traced_test] - async fn endpoint_conn_type_becomes_direct() -> TestResult { + async fn endpoint_conn_type_becomes_direct() -> Result { const TIMEOUT: Duration = 
std::time::Duration::from_secs(15); let (relay_map, _relay_url, _relay_guard) = run_relay_server().await?; let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); @@ -2700,8 +2826,8 @@ mod tests { .bind() .await?; - async fn wait_for_conn_type_direct(ep: &Endpoint, node_id: NodeId) -> TestResult { - let mut stream = ep.conn_type(node_id)?.stream(); + async fn wait_for_conn_type_direct(ep: &Endpoint, node_id: NodeId) -> Result { + let mut stream = ep.conn_type(node_id).expect("connection exists").stream(); let src = ep.node_id().fmt_short(); let dst = node_id.fmt_short(); while let Some(conn_type) = stream.next().await { @@ -2710,12 +2836,12 @@ mod tests { return Ok(()); } } - panic!("conn_type stream ended before `ConnectionType::Direct`"); + snafu::whatever!("conn_type stream ended before `ConnectionType::Direct`"); } - async fn accept(ep: &Endpoint) -> TestResult { + async fn accept(ep: &Endpoint) -> Result { let incoming = ep.accept().await.expect("ep closed"); - let conn = incoming.await?; + let conn = incoming.await.e()?; let node_id = conn.remote_node_id()?; tracing::info!(node_id=%node_id.fmt_short(), "accepted connection"); Ok(conn) @@ -2733,52 +2859,51 @@ mod tests { let ep1_side = tokio::time::timeout(TIMEOUT, async move { let conn = accept(&ep1).await?; - let mut send = conn.open_uni().await?; + let mut send = conn.open_uni().await.e()?; wait_for_conn_type_direct(&ep1, ep2_nodeid).await?; - send.write_all(b"Conn is direct").await?; - send.finish()?; + send.write_all(b"Conn is direct").await.e()?; + send.finish().e()?; conn.closed().await; - TestResult::Ok(()) + Ok::<(), Error>(()) }); let ep2_side = tokio::time::timeout(TIMEOUT, async move { let conn = ep2.connect(ep1_nodeaddr, TEST_ALPN).await?; - let mut recv = conn.accept_uni().await?; + let mut recv = conn.accept_uni().await.e()?; wait_for_conn_type_direct(&ep2, ep1_nodeid).await?; - let read = recv.read_to_end(100).await?; + let read = recv.read_to_end(100).await.e()?; assert_eq!(read, b"Conn is 
direct".to_vec()); conn.close(0u32.into(), b"done"); conn.closed().await; - TestResult::Ok(()) + Ok::<(), Error>(()) }); let res_ep1 = AbortOnDropHandle::new(tokio::spawn(ep1_side)); let res_ep2 = AbortOnDropHandle::new(tokio::spawn(ep2_side)); - let (r1, r2) = tokio::try_join!(res_ep1, res_ep2)?; - r1??; - r2??; + let (r1, r2) = tokio::try_join!(res_ep1, res_ep2).e()?; + r1.e()??; + r2.e()??; Ok(()) } #[tokio::test] #[traced_test] - async fn test_direct_addresses_no_stun_relay() { - let (relay_map, _, _guard) = run_relay_server_with(None, false).await.unwrap(); + async fn test_direct_addresses_no_stun_relay() -> Result { + let (relay_map, _, _guard) = run_relay_server_with(None, false).await?; let ep = Endpoint::builder() .alpns(vec![TEST_ALPN.to_vec()]) .relay_mode(RelayMode::Custom(relay_map)) .insecure_skip_relay_cert_verify(true) .bind() - .await - .unwrap(); + .await?; tokio::time::timeout(Duration::from_secs(10), ep.direct_addresses().initialized()) .await - .unwrap() - .unwrap(); + .e()??; + Ok(()) } async fn spawn_0rtt_server(secret_key: SecretKey, log_span: tracing::Span) -> Result { @@ -2795,7 +2920,7 @@ mod tests { let server = server.clone(); async move { while let Some(incoming) = server.accept().await { - let connecting = incoming.accept()?; + let connecting = incoming.accept().e()?; let conn = match connecting.into_0rtt() { Ok((conn, _)) => { info!("0rtt accepted"); @@ -2803,18 +2928,18 @@ mod tests { } Err(connecting) => { info!("0rtt denied"); - connecting.await? + connecting.await.e()? } }; - let (mut send, mut recv) = conn.accept_bi().await?; - let data = recv.read_to_end(10_000_000).await?; - send.write_all(&data).await?; - send.finish()?; + let (mut send, mut recv) = conn.accept_bi().await.e()?; + let data = recv.read_to_end(10_000_000).await.e()?; + send.write_all(&data).await.e()?; + send.finish().e()?; // Stay alive until the other side closes the connection. 
conn.closed().await; } - anyhow::Ok(()) + Ok::<_, Error>(()) } .instrument(log_span) }); @@ -2822,21 +2947,19 @@ mod tests { Ok(server) } - async fn connect_client_0rtt_expect_err( - client: &Endpoint, - server_addr: NodeAddr, - ) -> Result<()> { + async fn connect_client_0rtt_expect_err(client: &Endpoint, server_addr: NodeAddr) -> Result { let conn = client .connect_with_opts(server_addr, TEST_ALPN, ConnectOptions::new()) .await? .into_0rtt() .expect_err("expected 0rtt to fail") - .await?; + .await + .e()?; - let (mut send, mut recv) = conn.open_bi().await?; - send.write_all(b"hello").await?; - send.finish()?; - let received = recv.read_to_end(1_000).await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; + send.write_all(b"hello").await.e()?; + send.finish().e()?; + let received = recv.read_to_end(1_000).await.e()?; assert_eq!(&received, b"hello"); conn.close(0u32.into(), b"thx"); Ok(()) @@ -2846,18 +2969,18 @@ mod tests { client: &Endpoint, server_addr: NodeAddr, expect_server_accepts: bool, - ) -> Result<()> { + ) -> Result { let (conn, accepted_0rtt) = client .connect_with_opts(server_addr, TEST_ALPN, ConnectOptions::new()) .await? .into_0rtt() - .map_err(|_| "0rtt failed") - .unwrap(); + .ok() + .e()?; // This is how we send data in 0-RTT: - let (mut send, recv) = conn.open_bi().await?; - send.write_all(b"hello").await?; - send.finish()?; + let (mut send, recv) = conn.open_bi().await.e()?; + send.write_all(b"hello").await.e()?; + send.finish().e()?; // When this resolves, we've gotten a response from the server about whether the 0-RTT data above was accepted: let accepted = accepted_0rtt.await; assert_eq!(accepted, expect_server_accepts); @@ -2865,12 +2988,12 @@ mod tests { recv } else { // in this case we need to re-send data by re-creating the connection. 
- let (mut send, recv) = conn.open_bi().await?; - send.write_all(b"hello").await?; - send.finish()?; + let (mut send, recv) = conn.open_bi().await.e()?; + send.write_all(b"hello").await.e()?; + send.finish().e()?; recv }; - let received = recv.read_to_end(1_000).await?; + let received = recv.read_to_end(1_000).await.e()?; assert_eq!(&received, b"hello"); conn.close(0u32.into(), b"thx"); Ok(()) @@ -2878,7 +3001,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_0rtt() -> TestResult { + async fn test_0rtt() -> Result { let client = Endpoint::builder() .relay_mode(RelayMode::Disabled) .bind() @@ -2905,7 +3028,7 @@ mod tests { // receive into the respective "bucket" for the recipient. #[tokio::test] #[traced_test] - async fn test_0rtt_non_consecutive() -> TestResult { + async fn test_0rtt_non_consecutive() -> Result { let client = Endpoint::builder() .relay_mode(RelayMode::Disabled) .bind() @@ -2940,7 +3063,7 @@ mod tests { // Test whether 0-RTT is possible after a restart: #[tokio::test] #[traced_test] - async fn test_0rtt_after_server_restart() -> TestResult { + async fn test_0rtt_after_server_restart() -> Result { let client = Endpoint::builder() .relay_mode(RelayMode::Disabled) .bind() @@ -2969,7 +3092,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn graceful_close() -> testresult::TestResult { + async fn graceful_close() -> Result { let client = Endpoint::builder().bind().await?; let server = Endpoint::builder() .alpns(vec![TEST_ALPN.to_vec()]) @@ -2977,25 +3100,25 @@ mod tests { .await?; let server_addr = server.node_addr().initialized().await?; let server_task = tokio::spawn(async move { - let incoming = server.accept().await.unwrap(); - let conn = incoming.await?; - let (mut send, mut recv) = conn.accept_bi().await?; - let msg = recv.read_to_end(1_000).await?; - send.write_all(&msg).await?; - send.finish()?; + let incoming = server.accept().await.e()?; + let conn = incoming.await.e()?; + let (mut send, mut recv) = 
conn.accept_bi().await.e()?; + let msg = recv.read_to_end(1_000).await.e()?; + send.write_all(&msg).await.e()?; + send.finish().e()?; let close_reason = conn.closed().await; - testresult::TestResult::Ok(close_reason) + Ok::<_, Error>(close_reason) }); let conn = client.connect(server_addr, TEST_ALPN).await?; - let (mut send, mut recv) = conn.open_bi().await?; - send.write_all(b"Hello, world!").await?; - send.finish()?; - recv.read_to_end(1_000).await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; + send.write_all(b"Hello, world!").await.e()?; + send.finish().e()?; + recv.read_to_end(1_000).await.e()?; conn.close(42u32.into(), b"thanks, bye!"); client.close().await; - let close_err = server_task.await??; + let close_err = server_task.await.e()??; let ConnectionError::ApplicationClosed(app_close) = close_err else { panic!("Unexpected close reason: {close_err:?}"); }; @@ -3006,10 +3129,11 @@ mod tests { Ok(()) } + #[cfg(feature = "metrics")] #[tokio::test] #[traced_test] - async fn metrics_smoke() -> testresult::TestResult { - use iroh_metrics::Registry; + async fn metrics_smoke() -> Result { + use iroh_metrics::{MetricsSource, Registry}; let secret_key = SecretKey::from_bytes(&[0u8; 32]); let client = Endpoint::builder() @@ -3026,24 +3150,19 @@ mod tests { .await?; let server_addr = server.node_addr().initialized().await?; let server_task = tokio::task::spawn(async move { - let conn = server - .accept() - .await - .context("expected conn")? - .accept()? 
- .await?; - let mut uni = conn.accept_uni().await?; - uni.read_to_end(10).await?; + let conn = server.accept().await.e()?.accept().e()?.await.e()?; + let mut uni = conn.accept_uni().await.e()?; + uni.read_to_end(10).await.e()?; drop(conn); - anyhow::Ok(server) + Ok::<_, Error>(server) }); let conn = client.connect(server_addr, TEST_ALPN).await?; - let mut uni = conn.open_uni().await?; - uni.write_all(b"helloworld").await?; - uni.finish()?; + let mut uni = conn.open_uni().await.e()?; + uni.write_all(b"helloworld").await.e()?; + uni.finish().e()?; conn.closed().await; drop(conn); - let server = server_task.await??; + let server = server_task.await.e()??; let m = client.metrics(); assert_eq!(m.magicsock.num_direct_conns_added.get(), 1); @@ -3080,7 +3199,7 @@ mod tests { accept_alpns: Vec>, primary_connect_alpn: &[u8], secondary_connect_alpns: Vec>, - ) -> testresult::TestResult>> { + ) -> Result>> { let client = Endpoint::builder() .relay_mode(RelayMode::Disabled) .bind() @@ -3094,10 +3213,10 @@ mod tests { let server_task = tokio::spawn({ let server = server.clone(); async move { - let incoming = server.accept().await.unwrap(); - let conn = incoming.await?; + let incoming = server.accept().await.e()?; + let conn = incoming.await.e()?; conn.close(0u32.into(), b"bye!"); - testresult::TestResult::Ok(conn.alpn()) + Ok::<_, n0_snafu::Error>(conn.alpn()) } }); @@ -3108,13 +3227,13 @@ mod tests { ConnectOptions::new().with_additional_alpns(secondary_connect_alpns), ) .await?; - let conn = conn.await?; + let conn = conn.await.e()?; let client_alpn = conn.alpn(); conn.closed().await; client.close().await; server.close().await; - let server_alpn = server_task.await??; + let server_alpn = server_task.await.e()??; assert_eq!(client_alpn, server_alpn); @@ -3123,7 +3242,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn connect_multiple_alpn_negotiated() -> testresult::TestResult { + async fn connect_multiple_alpn_negotiated() -> Result { const ALPN_ONE: &[u8] = b"alpn/1"; 
const ALPN_TWO: &[u8] = b"alpn/2"; @@ -3174,7 +3293,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn watch_net_report() -> testresult::TestResult { + async fn watch_net_report() -> Result { let endpoint = Endpoint::builder() .relay_mode(RelayMode::Staging) .bind() diff --git a/iroh/src/key.rs b/iroh/src/key.rs index cfc29d3793d..39f1fc26862 100644 --- a/iroh/src/key.rs +++ b/iroh/src/key.rs @@ -3,6 +3,8 @@ use std::fmt::Debug; use aead::Buffer; +use nested_enum_utils::common_fields; +use snafu::{ensure, ResultExt, Snafu}; pub(crate) const NONCE_LEN: usize = 24; @@ -18,14 +20,20 @@ pub(super) fn secret_ed_box(key: &ed25519_dalek::SigningKey) -> crypto_box::Secr pub struct SharedSecret(crypto_box::ChaChaBox); /// Errors that can occur during [`SharedSecret::open`]. -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum DecryptionError { /// The nonce had the wrong size. - #[error("Invalid nonce")] - InvalidNonce, + #[snafu(display("Invalid nonce"))] + InvalidNonce {}, /// AEAD decryption failed. - #[error("Aead error")] - Aead(aead::Error), + #[snafu(display("Aead error"))] + Aead { source: aead::Error }, } impl Debug for SharedSecret { @@ -54,19 +62,17 @@ impl SharedSecret { /// Opens the ciphertext, which must have been created using `Self::seal`, and places the clear text into the provided buffer. pub fn open(&self, buffer: &mut dyn Buffer) -> Result<(), DecryptionError> { use aead::AeadInPlace; - if buffer.len() < NONCE_LEN { - return Err(DecryptionError::InvalidNonce); - } + ensure!(buffer.len() >= NONCE_LEN, InvalidNonceSnafu); let offset = buffer.len() - NONCE_LEN; let nonce: [u8; NONCE_LEN] = buffer.as_ref()[offset..] 
.try_into() - .map_err(|_| DecryptionError::InvalidNonce)?; + .map_err(|_| InvalidNonceSnafu.build())?; buffer.truncate(offset); self.0 .decrypt_in_place(&nonce.into(), &[], buffer) - .map_err(DecryptionError::Aead)?; + .context(AeadSnafu)?; Ok(()) } diff --git a/iroh/src/lib.rs b/iroh/src/lib.rs index 42b038e3c0f..2f44c3420e6 100644 --- a/iroh/src/lib.rs +++ b/iroh/src/lib.rs @@ -9,12 +9,16 @@ //! //! ```no_run //! # use iroh::{Endpoint, NodeAddr}; -//! # async fn wrapper() -> testresult::TestResult { +//! # use n0_snafu::ResultExt; +//! # async fn wrapper() -> n0_snafu::Result { //! let addr: NodeAddr = todo!(); //! let ep = Endpoint::builder().bind().await?; //! let conn = ep.connect(addr, b"my-alpn").await?; -//! let mut send_stream = conn.open_uni().await?; -//! send_stream.write_all(b"msg").await?; +//! let mut send_stream = conn.open_uni().await.context("unable to open uni")?; +//! send_stream +//! .write_all(b"msg") +//! .await +//! .context("unable to write all")?; //! # Ok(()) //! # } //! ``` @@ -23,15 +27,24 @@ //! //! ```no_run //! # use iroh::{Endpoint, NodeAddr}; -//! # async fn wrapper() -> testresult::TestResult { +//! # use n0_snafu::ResultExt; +//! # async fn wrapper() -> n0_snafu::Result { //! let ep = Endpoint::builder() //! .alpns(vec![b"my-alpn".to_vec()]) //! .bind() //! .await?; -//! let conn = ep.accept().await.ok_or("err")?.await?; -//! let mut recv_stream = conn.accept_uni().await?; +//! let conn = ep +//! .accept() +//! .await +//! .context("accept error")? +//! .await +//! .context("connecting error")?; +//! let mut recv_stream = conn.accept_uni().await.context("unable to open uni")?; //! let mut buf = [0u8; 3]; -//! recv_stream.read_exact(&mut buf).await?; +//! recv_stream +//! .read_exact(&mut buf) +//! .await +//! .context("unable to read")?; //! # Ok(()) //! # } //! ``` @@ -157,8 +170,8 @@ //! The central struct is the [`Endpoint`], which allows you to connect to other nodes: //! //! ```no_run -//! use anyhow::Result; //! 
use iroh::{Endpoint, NodeAddr}; +//! use n0_snafu::{Result, ResultExt}; //! //! async fn connect(addr: NodeAddr) -> Result<()> { //! // The Endpoint is the central object that manages an iroh node. @@ -166,10 +179,10 @@ //! //! // Establish a QUIC connection, open a bi-directional stream, exchange messages. //! let conn = ep.connect(addr, b"hello-world").await?; -//! let (mut send_stream, mut recv_stream) = conn.open_bi().await?; -//! send_stream.write_all(b"hello").await?; -//! send_stream.finish()?; -//! let _msg = recv_stream.read_to_end(10).await?; +//! let (mut send_stream, mut recv_stream) = conn.open_bi().await.context("open bi")?; +//! send_stream.write_all(b"hello").await.context("write")?; +//! send_stream.finish().context("finish")?; +//! let _msg = recv_stream.read_to_end(10).await.context("read")?; //! //! // Gracefully close the connection and endpoint. //! conn.close(1u8.into(), b"done"); @@ -182,9 +195,9 @@ //! Every [`Endpoint`] can also accept connections: //! //! ```no_run -//! use anyhow::{Context, Result}; //! use futures_lite::StreamExt; //! use iroh::{Endpoint, NodeAddr}; +//! use n0_snafu::{Result, ResultExt}; //! //! async fn accept() -> Result<()> { //! // To accept connections at least one ALPN must be configured. @@ -194,11 +207,16 @@ //! .await?; //! //! // Accept a QUIC connection, accept a bi-directional stream, exchange messages. -//! let conn = ep.accept().await.context("no incoming connection")?.await?; -//! let (mut send_stream, mut recv_stream) = conn.accept_bi().await?; -//! let _msg = recv_stream.read_to_end(10).await?; -//! send_stream.write_all(b"world").await?; -//! send_stream.finish()?; +//! let conn = ep +//! .accept() +//! .await +//! .context("no incoming connection")? +//! .await +//! .context("accept conn")?; +//! let (mut send_stream, mut recv_stream) = conn.accept_bi().await.context("accept stream")?; +//! let _msg = recv_stream.read_to_end(10).await.context("read")?; +//! 
send_stream.write_all(b"world").await.context("write")?; +//! send_stream.finish().context("finish")?; //! //! // Wait for the client to close the connection and gracefully close the endpoint. //! conn.closed().await; diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index 21936aa2cbf..3bd24a41d16 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -28,7 +28,6 @@ use std::{ task::{Context, Poll, Waker}, }; -use anyhow::{anyhow, Context as _, Result}; use atomic_waker::AtomicWaker; use bytes::Bytes; use concurrent_queue::ConcurrentQueue; @@ -41,13 +40,18 @@ use n0_future::{ time::{self, Duration, Instant}, FutureExt, StreamExt, }; -use netwatch::{interfaces, netmon}; +use nested_enum_utils::common_fields; +use netwatch::{ + interfaces, + netmon::{self, CallbackToken}, +}; #[cfg(not(wasm_browser))] use netwatch::{ip::LocalAddresses, UdpSocket}; use quinn::{AsyncUdpSocket, ServerConfig}; use rand::{seq::SliceRandom, Rng, SeedableRng}; use relay_actor::RelaySendItem; use smallvec::{smallvec, SmallVec}; +use snafu::{IntoError, ResultExt, Snafu}; use tokio::sync::{self, mpsc, Mutex}; use tokio_util::sync::CancellationToken; use tracing::{ @@ -75,7 +79,7 @@ use crate::{ discovery::{Discovery, DiscoveryItem, DiscoverySubscribers, NodeData, UserData}, key::{public_ed_box, secret_ed_box, DecryptionError, SharedSecret}, metrics::EndpointMetrics, - net_report::{self, IpMappedAddresses, Report}, + net_report::{self, IpMappedAddresses, Report, ReportError}, watcher::{self, Watchable}, }; @@ -280,9 +284,27 @@ pub(crate) struct SocketState { local_addrs: std::sync::RwLock<(SocketAddr, Option)>, } +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[snafu(visibility(pub(crate)))] +#[non_exhaustive] +pub enum AddNodeAddrError { + #[snafu(display("Empty addressing info"))] + Empty {}, + #[snafu(display("Empty addressing info, {pruned} direct address have been 
pruned"))] + EmptyPruned { pruned: usize }, + #[snafu(display("Adding our own address is not supported"))] + OwnAddress {}, +} + impl MagicSock { /// Creates a magic [`MagicSock`] listening on [`Options::addr_v4`] and [`Options::addr_v6`]. - pub(crate) async fn spawn(opts: Options) -> Result { + pub(crate) async fn spawn(opts: Options) -> Result { Handle::new(opts).await } @@ -388,19 +410,17 @@ impl MagicSock { self.my_relay.watch() } - /// Returns a [`Watcher`] that reports the [`ConnectionType`] we have to the + /// Returns a [`watcher::Direct`] that reports the [`ConnectionType`] we have to the /// given `node_id`. /// - /// This gets us a copy of the [`Watcher`] for the [`Watchable`] with a [`ConnectionType`] + /// This gets us a copy of the [`watcher::Direct`] for the [`Watchable`] with a [`ConnectionType`] /// that the `NodeMap` stores for each `node_id`'s endpoint. /// /// # Errors /// - /// Will return an error if there is no address information known about the + /// Will return `None` if there is no address information known about the /// given `node_id`. - /// - /// [`Watcher`]: crate::watcher::Watcher - pub(crate) fn conn_type(&self, node_id: NodeId) -> Result> { + pub(crate) fn conn_type(&self, node_id: NodeId) -> Option> { self.node_map.conn_type(node_id) } @@ -411,8 +431,12 @@ impl MagicSock { /// Add addresses for a node to the magic socket's addresbook. 
#[instrument(skip_all, fields(me = %self.me))] - pub fn add_node_addr(&self, mut addr: NodeAddr, source: node_map::Source) -> Result<()> { - let mut pruned = 0; + pub fn add_node_addr( + &self, + mut addr: NodeAddr, + source: node_map::Source, + ) -> Result<(), AddNodeAddrError> { + let mut pruned: usize = 0; for my_addr in self.direct_addrs.sockaddrs() { if addr.direct_addresses.remove(&my_addr) { warn!( node_id=addr.node_id.fmt_short(), %my_addr, %source, "not adding our addr for node"); @@ -424,11 +448,9 @@ impl MagicSock { .add_node_addr(addr, source, &self.metrics.magicsock); Ok(()) } else if pruned != 0 { - Err(anyhow::anyhow!( - "empty addressing info, {pruned} direct addresses have been pruned" - )) + Err(EmptyPrunedSnafu { pruned }.build()) } else { - Err(anyhow::anyhow!("empty addressing info")) + Err(EmptySnafu.build()) } } @@ -1161,12 +1183,12 @@ impl MagicSock { sealed_box.to_vec(), ) { Ok(dm) => dm, - Err(DiscoBoxError::Open(err)) => { - warn!(?err, "failed to open disco box"); + Err(DiscoBoxError::Open { source, .. }) => { + warn!(?source, "failed to open disco box"); self.metrics.magicsock.recv_disco_bad_key.inc(); return; } - Err(DiscoBoxError::Parse(err)) => { + Err(DiscoBoxError::Parse { source, .. }) => { // Couldn't parse it, but it was inside a correctly // signed box, so just ignore it, assuming it's from a // newer version of Tailscale that we don't @@ -1174,7 +1196,7 @@ impl MagicSock { // be too spammy for old clients. 
self.metrics.magicsock.recv_disco_bad_parse.inc(); - debug!(?err, "failed to parse disco message"); + debug!(?source, "failed to parse disco message"); return; } }; @@ -1708,9 +1730,30 @@ impl DirectAddrUpdateState { } } +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum CreateHandleError { + #[snafu(display("Failed to create bind sockets"))] + BindSockets { source: io::Error }, + #[snafu(display("Failed to create internal quinn endpoint"))] + CreateQuinnEndpoint { source: io::Error }, + #[snafu(display("Failed to create socket state"))] + CreateSocketState { source: io::Error }, + #[snafu(display("Failed to create netmon monitor"))] + CreateNetmonMonitor { source: netmon::Error }, + #[snafu(display("Failed to subscribe netmon monitor"))] + SubscribeNetmonMonitor { source: netmon::Error }, +} + impl Handle { /// Creates a magic [`MagicSock`] listening on [`Options::addr_v4`] and [`Options::addr_v6`]. 
- async fn new(opts: Options) -> Result { + async fn new(opts: Options) -> Result { let me = opts.secret_key.public().fmt_short(); Self::with_name(me, opts) @@ -1718,7 +1761,7 @@ impl Handle { .await } - async fn with_name(me: String, opts: Options) -> Result { + async fn with_name(me: String, opts: Options) -> Result { let Options { addr_v4, addr_v6, @@ -1740,10 +1783,13 @@ impl Handle { } = opts; #[cfg(not(wasm_browser))] - let actor_sockets = ActorSocketState::bind(addr_v4, addr_v6, metrics.portmapper.clone())?; + let actor_sockets = ActorSocketState::bind(addr_v4, addr_v6, metrics.portmapper.clone()) + .context(BindSocketsSnafu)?; #[cfg(not(wasm_browser))] - let sockets = actor_sockets.msock_socket_state()?; + let sockets = actor_sockets + .msock_socket_state() + .context(CreateSocketStateSnafu)?; let ip_mapped_addrs = IpMappedAddresses::default(); @@ -1755,7 +1801,7 @@ impl Handle { #[cfg(not(wasm_browser))] Some(ip_mapped_addrs.clone()), metrics.net_report.clone(), - )?; + ); let (actor_sender, actor_receiver) = mpsc::channel(256); let (relay_actor_sender, relay_actor_receiver) = mpsc::channel(256); @@ -1823,7 +1869,8 @@ impl Handle { Arc::new(quinn::TokioRuntime), #[cfg(wasm_browser)] Arc::new(crate::web_runtime::WebRuntime), - )?; + ) + .context(CreateQuinnEndpointSnafu)?; let mut actor_tasks = JoinSet::default(); @@ -1850,7 +1897,9 @@ impl Handle { } }); - let network_monitor = netmon::Monitor::new().await?; + let network_monitor = netmon::Monitor::new() + .await + .context(CreateNetmonMonitorSnafu)?; let qad_endpoint = endpoint.clone(); // create a client config for the endpoint to use for QUIC address discovery @@ -1879,6 +1928,19 @@ impl Handle { #[cfg(wasm_browser)] let net_report_config = net_report::Options::default(); + // Setup network monitoring + let (link_change_s, link_change_r) = mpsc::channel(8); + let netmon_token = network_monitor + .subscribe(move |is_major| { + let link_change_s = link_change_s.clone(); + async move { + 
link_change_s.send(is_major).await.ok(); + } + .boxed() + }) + .await + .context(SubscribeNetmonMonitorSnafu)?; + actor_tasks.spawn({ let msock = msock.clone(); async move { @@ -1898,9 +1960,7 @@ impl Handle { net_report_config, }; - if let Err(err) = actor.run().await { - warn!("relay handler errored: {:?}", err); - } + actor.run(link_change_r, netmon_token).await; } .instrument(info_span!("actor")) }); @@ -2022,17 +2082,31 @@ impl DiscoSecrets { node_id: PublicKey, mut sealed_box: Vec, ) -> Result { - self.get(secret, node_id, |secret| secret.open(&mut sealed_box))?; - disco::Message::from_bytes(&sealed_box).map_err(DiscoBoxError::Parse) + self.get(secret, node_id, |secret| secret.open(&mut sealed_box)) + .context(OpenSnafu)?; + disco::Message::from_bytes(&sealed_box).context(ParseSnafu) } } -#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] enum DiscoBoxError { - #[error("Failed to open crypto box")] - Open(#[from] DecryptionError), - #[error("Failed to parse disco message")] - Parse(anyhow::Error), + #[snafu(display("Failed to open crypto box"))] + Open { + #[snafu(source(from(DecryptionError, Box::new)))] + source: Box, + }, + #[snafu(display("Failed to parse disco message"))] + Parse { + #[snafu(source(from(disco::ParseError, Box::new)))] + source: Box, + }, } /// Creates a sender and receiver pair for sending datagrams to the [`RelayActor`]. @@ -2126,6 +2200,13 @@ struct RelayDatagramRecvQueue { waker: AtomicWaker, } +#[derive(Debug, Snafu)] +#[non_exhaustive] +enum RelayRecvDatagramError { + #[snafu(display("Queue is closed"))] + Closed, +} + impl RelayDatagramRecvQueue { /// Creates a new, empty queue with a fixed size bound of 512 items. 
fn new() -> Self { @@ -2160,7 +2241,10 @@ impl RelayDatagramRecvQueue { /// The reason this method is made available as `&self` is because /// the interface for quinn's [`AsyncUdpSocket::poll_recv`] requires us /// to be able to poll from `&self`. - fn poll_recv(&self, cx: &mut Context) -> Poll> { + fn poll_recv( + &self, + cx: &mut Context, + ) -> Poll> { match self.queue.pop() { Ok(value) => Poll::Ready(Ok(value)), Err(concurrent_queue::PopError::Empty) => { @@ -2174,11 +2258,11 @@ impl RelayDatagramRecvQueue { Err(concurrent_queue::PopError::Empty) => Poll::Pending, Err(concurrent_queue::PopError::Closed) => { self.waker.take(); - Poll::Ready(Err(anyhow!("Queue closed"))) + Poll::Ready(Err(ClosedSnafu.build())) } } } - Err(concurrent_queue::PopError::Closed) => Poll::Ready(Err(anyhow!("Queue closed"))), + Err(concurrent_queue::PopError::Closed) => Poll::Ready(Err(ClosedSnafu.build())), } } } @@ -2334,7 +2418,10 @@ impl quinn::UdpPoller for IoPoller { enum ActorMessage { Shutdown, EndpointPingExpired(usize, stun_rs::TransactionId), - NetReport(Result>>, &'static str), + NetReport( + Result>, NetReportError>, + &'static str, + ), NetworkChange, #[cfg(test)] ForceNetworkChange(bool), @@ -2388,7 +2475,7 @@ impl ActorSocketState { addr_v4: Option, addr_v6: Option, metrics: Arc, - ) -> Result { + ) -> io::Result { let port_mapper = portmapper::Client::with_metrics(Default::default(), metrics); let (v4, v6) = Self::bind_sockets(addr_v4, addr_v6)?; @@ -2419,9 +2506,9 @@ impl ActorSocketState { fn bind_sockets( addr_v4: Option, addr_v6: Option, - ) -> Result<(Arc, Option>)> { + ) -> io::Result<(Arc, Option>)> { let addr_v4 = addr_v4.unwrap_or_else(|| SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); - let v4 = Arc::new(bind_with_fallback(SocketAddr::V4(addr_v4)).context("bind IPv4 failed")?); + let v4 = Arc::new(bind_with_fallback(SocketAddr::V4(addr_v4))?); let ip4_port = v4.local_addr()?.port(); let ip6_port = ip4_port.checked_add(1).unwrap_or(ip4_port - 1); @@ -2438,7 
+2525,7 @@ impl ActorSocketState { Ok((v4, v6)) } - fn msock_socket_state(&self) -> Result { + fn msock_socket_state(&self) -> io::Result { let ipv4_addr = self.v4.local_addr()?; let ipv6_addr = self.v6.as_ref().and_then(|c| c.local_addr().ok()); @@ -2453,21 +2540,19 @@ impl ActorSocketState { } } -impl Actor { - async fn run(mut self) -> Result<()> { - // Setup network monitoring - let (link_change_s, mut link_change_r) = mpsc::channel(8); - let _token = self - .network_monitor - .subscribe(move |is_major| { - let link_change_s = link_change_s.clone(); - async move { - link_change_s.send(is_major).await.ok(); - } - .boxed() - }) - .await?; +#[derive(Debug, Snafu)] +#[non_exhaustive] +enum NetReportError { + #[snafu(display("Net report not received"))] + NotReceived, + #[snafu(display("Net report timed out"))] + Timeout, + #[snafu(display("Net report encountered an error"))] + NetReport { source: ReportError }, +} +impl Actor { + async fn run(mut self, mut link_change_r: mpsc::Receiver, _token: CallbackToken) { // Let the the heartbeat only start a couple seconds later #[cfg(not(wasm_browser))] let mut direct_addr_heartbeat_timer = time::interval_at( @@ -2515,7 +2600,7 @@ impl Actor { trace!(?msg, "tick: msg"); self.msock.metrics.magicsock.actor_tick_msg.inc(); if self.handle_actor_message(msg).await { - return Ok(()); + return; } } tick = self.periodic_re_stun_timer.tick() => { @@ -2907,11 +2992,11 @@ impl Actor { let msg_sender = self.msg_sender.clone(); task::spawn(async move { let report = time::timeout(NET_REPORT_TIMEOUT, rx).await; - let report: anyhow::Result<_> = match report { + let report = match report { Ok(Ok(Ok(report))) => Ok(Some(report)), - Ok(Ok(Err(err))) => Err(err), - Ok(Err(_)) => Err(anyhow!("net_report report not received")), - Err(err) => Err(anyhow!("net_report report timeout: {:?}", err)), + Ok(Ok(Err(err))) => Err(NetReportSnafu.into_error(err)), + Ok(Err(_)) => Err(NotReceivedSnafu.build()), + Err(_) => Err(TimeoutSnafu.build()), }; 
msg_sender .send(ActorMessage::NetReport(report, why)) @@ -3096,13 +3181,13 @@ fn new_re_stun_timer(initial_delay: bool) -> time::Interval { } #[cfg(not(wasm_browser))] -fn bind_with_fallback(mut addr: SocketAddr) -> anyhow::Result { +fn bind_with_fallback(mut addr: SocketAddr) -> io::Result { debug!(%addr, "binding"); // First try binding a preferred port, if specified match UdpSocket::bind_full(addr) { Ok(socket) => { - let local_addr = socket.local_addr().context("UDP socket not bound")?; + let local_addr = socket.local_addr()?; debug!(%addr, %local_addr, "successfully bound"); return Ok(socket); } @@ -3110,14 +3195,14 @@ fn bind_with_fallback(mut addr: SocketAddr) -> anyhow::Result { debug!(%addr, "failed to bind: {err:#}"); // If that was already the fallback port, then error out if addr.port() == 0 { - return Err(err.into()); + return Err(err); } } } // Otherwise, try binding with port 0 addr.set_port(0); - UdpSocket::bind_full(addr).context("failed to bind on fallback port") + UdpSocket::bind_full(addr) } /// The discovered direct addresses of this [`MagicSock`]. @@ -3233,8 +3318,8 @@ fn split_packets(transmit: &quinn_udp::Transmit) -> RelayContents { pub(crate) struct NodeIdMappedAddr(Ipv6Addr); /// Can occur when converting a [`SocketAddr`] to an [`NodeIdMappedAddr`] -#[derive(Debug, thiserror::Error)] -#[error("Failed to convert")] +#[derive(Debug, Snafu)] +#[snafu(display("Failed to convert"))] pub struct NodeIdMappedAddrError; /// Counter to always generate unique addresses for [`NodeIdMappedAddr`]. 
@@ -3281,7 +3366,7 @@ impl NodeIdMappedAddr { impl TryFrom for NodeIdMappedAddr { type Error = NodeIdMappedAddrError; - fn try_from(value: Ipv6Addr) -> std::result::Result { + fn try_from(value: Ipv6Addr) -> Result { let octets = value.octets(); if octets[0] == Self::ADDR_PREFIXL && octets[1..6] == Self::ADDR_GLOBAL_ID @@ -3447,15 +3532,29 @@ impl NetInfo { #[cfg(test)] mod tests { - use anyhow::Context; - use rand::RngCore; + use std::{collections::BTreeSet, sync::Arc, time::Duration}; + + use bytes::Bytes; + use data_encoding::HEXLOWER; + use iroh_base::{NodeAddr, NodeId, PublicKey, RelayUrl, SecretKey}; + use iroh_relay::RelayMap; + use n0_future::{time, StreamExt}; + use n0_snafu::{Result, ResultExt}; + use quinn::ServerConfig; + use rand::{Rng, RngCore}; + use tokio::task::JoinSet; use tokio_util::task::AbortOnDropHandle; + use tracing::{debug, error, info, info_span, instrument, Instrument}; use tracing_test::traced_test; - use super::*; + use super::{split_packets, NodeIdMappedAddr, Options}; use crate::{ defaults::staging::{self, EU_RELAY_HOSTNAME}, dns::DnsResolver, + endpoint::{DirectAddr, PathSelection, Source}, + magicsock::{ + node_map, Handle, MagicSock, RelayContents, RelayDatagramRecvQueue, RelayRecvDatagram, + }, tls, watcher::Watcher as _, Endpoint, RelayMode, @@ -3522,7 +3621,7 @@ mod tests { } impl MagicStack { - async fn new(relay_mode: RelayMode) -> Result { + async fn new(relay_mode: RelayMode) -> Self { let secret_key = SecretKey::generate(rand::thread_rng()); let mut transport_config = quinn::TransportConfig::default(); @@ -3534,12 +3633,13 @@ mod tests { .relay_mode(relay_mode) .alpns(vec![ALPN.to_vec()]) .bind() - .await?; + .await + .unwrap(); - Ok(Self { + Self { secret_key, endpoint, - }) + } } fn tracked_endpoints(&self) -> Vec { @@ -3624,39 +3724,35 @@ mod tests { } }) .await - .context("failed to connect nodes")?; + .context("timeout")?; info!("all nodes meshed"); Ok(tasks) } #[instrument(skip_all, fields(me = 
%ep.endpoint.node_id().fmt_short()))] - async fn echo_receiver(ep: MagicStack, loss: ExpectedLoss) -> Result<()> { + async fn echo_receiver(ep: MagicStack, loss: ExpectedLoss) -> Result { info!("accepting conn"); let conn = ep.endpoint.accept().await.expect("no conn"); info!("connecting"); - let conn = conn.await.context("[receiver] connecting")?; + let conn = conn.await.context("connecting")?; info!("accepting bi"); - let (mut send_bi, mut recv_bi) = - conn.accept_bi().await.context("[receiver] accepting bi")?; + let (mut send_bi, mut recv_bi) = conn.accept_bi().await.context("accept bi")?; info!("reading"); let val = recv_bi .read_to_end(usize::MAX) .await - .context("[receiver] reading to end")?; + .context("read to end")?; info!("replying"); for chunk in val.chunks(12) { - send_bi - .write_all(chunk) - .await - .context("[receiver] sending chunk")?; + send_bi.write_all(chunk).await.context("write all")?; } info!("finishing"); - send_bi.finish().context("[receiver] finishing")?; - send_bi.stopped().await.context("[receiver] stopped")?; + send_bi.finish().context("finish")?; + send_bi.stopped().await.context("stopped")?; let stats = conn.stats(); info!("stats: {:#?}", stats); @@ -3682,27 +3778,26 @@ mod tests { dest_id: PublicKey, msg: &[u8], loss: ExpectedLoss, - ) -> Result<()> { + ) -> Result { info!("connecting to {}", dest_id.fmt_short()); let dest = NodeAddr::new(dest_id); - let conn = ep - .endpoint - .connect(dest, ALPN) - .await - .context("[sender] connect")?; + let conn = ep.endpoint.connect(dest, ALPN).await?; info!("opening bi"); - let (mut send_bi, mut recv_bi) = conn.open_bi().await.context("[sender] open bi")?; + let (mut send_bi, mut recv_bi) = conn.open_bi().await.context("open bi")?; info!("writing message"); - send_bi.write_all(msg).await.context("[sender] write all")?; + send_bi.write_all(msg).await.context("write all")?; info!("finishing"); - send_bi.finish().context("[sender] finish")?; - send_bi.stopped().await.context("[sender] 
stopped")?; + send_bi.finish().context("finish")?; + send_bi.stopped().await.context("stopped")?; info!("reading_to_end"); - let val = recv_bi.read_to_end(usize::MAX).await.context("[sender]")?; + let val = recv_bi + .read_to_end(usize::MAX) + .await + .context("read to end")?; assert_eq!( val, msg, @@ -3775,9 +3870,9 @@ mod tests { #[tokio::test(flavor = "multi_thread")] #[traced_test] - async fn test_two_devices_roundtrip_quinn_magic() -> Result<()> { - let m1 = MagicStack::new(RelayMode::Disabled).await?; - let m2 = MagicStack::new(RelayMode::Disabled).await?; + async fn test_two_devices_roundtrip_quinn_magic() -> Result { + let m1 = MagicStack::new(RelayMode::Disabled).await; + let m2 = MagicStack::new(RelayMode::Disabled).await; let _guard = mesh_stacks(vec![m1.clone(), m2.clone()]).await?; @@ -3810,10 +3905,9 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_regression_network_change_rebind_wakes_connection_driver( - ) -> testresult::TestResult { - let m1 = MagicStack::new(RelayMode::Disabled).await?; - let m2 = MagicStack::new(RelayMode::Disabled).await?; + async fn test_regression_network_change_rebind_wakes_connection_driver() -> n0_snafu::Result { + let m1 = MagicStack::new(RelayMode::Disabled).await; + let m2 = MagicStack::new(RelayMode::Disabled).await; println!("Net change"); m1.endpoint.magic_sock().force_network_change(true).await; @@ -3826,11 +3920,11 @@ mod tests { async move { while let Some(incoming) = endpoint.accept().await { println!("Incoming first conn!"); - let conn = incoming.await?; + let conn = incoming.await.e()?; conn.closed().await; } - testresult::TestResult::Ok(()) + Ok::<_, n0_snafu::Error>(()) } })); @@ -3849,19 +3943,20 @@ mod tests { #[tokio::test(flavor = "multi_thread")] #[traced_test] - async fn test_two_devices_roundtrip_network_change() -> Result<()> { + async fn test_two_devices_roundtrip_network_change() -> Result { time::timeout( Duration::from_secs(90), test_two_devices_roundtrip_network_change_impl(), ) - 
.await? + .await + .context("timeout")? } /// Same structure as `test_two_devices_roundtrip_quinn_magic`, but interrupts regularly /// with (simulated) network changes. - async fn test_two_devices_roundtrip_network_change_impl() -> Result<()> { - let m1 = MagicStack::new(RelayMode::Disabled).await?; - let m2 = MagicStack::new(RelayMode::Disabled).await?; + async fn test_two_devices_roundtrip_network_change_impl() -> Result { + let m1 = MagicStack::new(RelayMode::Disabled).await; + let m2 = MagicStack::new(RelayMode::Disabled).await; let _guard = mesh_stacks(vec![m1.clone(), m2.clone()]).await?; @@ -3957,12 +4052,12 @@ mod tests { #[tokio::test(flavor = "multi_thread")] #[traced_test] - async fn test_two_devices_setup_teardown() -> Result<()> { + async fn test_two_devices_setup_teardown() -> Result { for i in 0..10 { println!("-- round {i}"); println!("setting up magic stack"); - let m1 = MagicStack::new(RelayMode::Disabled).await?; - let m2 = MagicStack::new(RelayMode::Disabled).await?; + let m1 = MagicStack::new(RelayMode::Disabled).await; + let m2 = MagicStack::new(RelayMode::Disabled).await; let _guard = mesh_stacks(vec![m1.clone(), m2.clone()]).await?; @@ -4071,10 +4166,7 @@ mod tests { /// /// Use [`magicsock_connect`] to establish connections. 
#[instrument(name = "ep", skip_all, fields(me = secret_key.public().fmt_short()))] - async fn magicsock_ep( - secret_key: SecretKey, - tls_auth: tls::Authentication, - ) -> anyhow::Result { + async fn magicsock_ep(secret_key: SecretKey, tls_auth: tls::Authentication) -> Result { let quic_server_config = tls::TlsConfig::new(tls_auth, secret_key.clone()) .make_server_config(vec![ALPN.to_vec()], true); let mut server_config = ServerConfig::with_crypto(Arc::new(quic_server_config)); @@ -4146,12 +4238,14 @@ mod tests { tls::TlsConfig::new(tls_auth, ep_secret_key.clone()).make_client_config(alpns, true); let mut client_config = quinn::ClientConfig::new(Arc::new(quic_client_config)); client_config.transport_config(transport_config); - let connect = ep.connect_with( - client_config, - mapped_addr.private_socket_addr(), - &tls::name::encode(node_id), - )?; - let connection = connect.await?; + let connect = ep + .connect_with( + client_config, + mapped_addr.private_socket_addr(), + &tls::name::encode(node_id), + ) + .context("connect")?; + let connection = connect.await.e()?; Ok(connection) } @@ -4197,8 +4291,12 @@ mod tests { // This needs an accept task let accept_task = tokio::spawn({ async fn accept(ep: quinn::Endpoint) -> Result<()> { - let incoming = ep.accept().await.ok_or(anyhow!("no incoming"))?; - let _conn = incoming.accept()?.await?; + let incoming = ep.accept().await.context("no incoming")?; + let _conn = incoming + .accept() + .context("accept")? + .await + .context("connecting")?; // Keep this connection alive for a while tokio::time::sleep(Duration::from_secs(10)).await; @@ -4275,10 +4373,14 @@ mod tests { // We need a task to accept the connection. 
let accept_task = tokio::spawn({ async fn accept(ep: quinn::Endpoint) -> Result<()> { - let incoming = ep.accept().await.ok_or(anyhow!("no incoming"))?; - let conn = incoming.accept()?.await?; - let mut stream = conn.accept_uni().await?; - stream.read_to_end(1 << 16).await?; + let incoming = ep.accept().await.context("no incoming")?; + let conn = incoming + .accept() + .context("accept")? + .await + .context("connecting")?; + let mut stream = conn.accept_uni().await.context("accept uni")?; + stream.read_to_end(1 << 16).await.context("read to end")?; info!("accept finished"); Ok(()) } @@ -4430,8 +4532,8 @@ mod tests { } #[tokio::test] - async fn test_add_node_addr() -> Result<()> { - let stack = MagicStack::new(RelayMode::Default).await?; + async fn test_add_node_addr() -> Result { + let stack = MagicStack::new(RelayMode::Default).await; let mut rng = rand::thread_rng(); assert_eq!(stack.endpoint.magic_sock().node_map.node_count(), 0); @@ -4447,12 +4549,15 @@ mod tests { .magic_sock() .add_node_addr(empty_addr, node_map::Source::App) .unwrap_err(); - assert!(err.to_string().contains("empty addressing info")); + assert!(err + .to_string() + .to_lowercase() + .contains("empty addressing info")); // relay url only let addr = NodeAddr { node_id: SecretKey::generate(&mut rng).public(), - relay_url: Some("http://my-relay.com".parse()?), + relay_url: Some("http://my-relay.com".parse().unwrap()), direct_addresses: Default::default(), }; stack @@ -4465,7 +4570,7 @@ mod tests { let addr = NodeAddr { node_id: SecretKey::generate(&mut rng).public(), relay_url: None, - direct_addresses: ["127.0.0.1:1234".parse()?].into_iter().collect(), + direct_addresses: ["127.0.0.1:1234".parse().unwrap()].into_iter().collect(), }; stack .endpoint @@ -4476,8 +4581,8 @@ mod tests { // both let addr = NodeAddr { node_id: SecretKey::generate(&mut rng).public(), - relay_url: Some("http://my-relay.com".parse()?), - direct_addresses: ["127.0.0.1:1234".parse()?].into_iter().collect(), + relay_url: 
Some("http://my-relay.com".parse().unwrap()), + direct_addresses: ["127.0.0.1:1234".parse().unwrap()].into_iter().collect(), }; stack .endpoint diff --git a/iroh/src/magicsock/node_map.rs b/iroh/src/magicsock/node_map.rs index 5b239b0a385..6392ebeb0c4 100644 --- a/iroh/src/magicsock/node_map.rs +++ b/iroh/src/magicsock/node_map.rs @@ -303,18 +303,13 @@ impl NodeMap { .collect() } - /// Returns a [`Watcher`] for given node's [`ConnectionType`]. + /// Returns a [`watcher::Direct`] for given node's [`ConnectionType`]. /// /// # Errors /// - /// Will return an error if there is not an entry in the [`NodeMap`] for + /// Will return `None` if there is not an entry in the [`NodeMap`] for /// the `node_id` - /// - /// [`Watcher`]: crate::watcher::Watcher - pub(super) fn conn_type( - &self, - node_id: NodeId, - ) -> anyhow::Result> { + pub(super) fn conn_type(&self, node_id: NodeId) -> Option> { self.inner.lock().expect("poisoned").conn_type(node_id) } @@ -508,13 +503,11 @@ impl NodeMapInner { /// /// # Errors /// - /// Will return an error if there is not an entry in the [`NodeMap`] for + /// Will return `None` if there is not an entry in the [`NodeMap`] for /// the `public_key` - fn conn_type(&self, node_id: NodeId) -> anyhow::Result> { - match self.get(NodeStateKey::NodeId(node_id)) { - Some(ep) => Ok(ep.conn_type()), - None => anyhow::bail!("No endpoint for {node_id:?} found"), - } + fn conn_type(&self, node_id: NodeId) -> Option> { + self.get(NodeStateKey::NodeId(node_id)) + .map(|ep| ep.conn_type()) } fn handle_pong(&mut self, sender: NodeId, src: &DiscoMessageSource, pong: Pong) { diff --git a/iroh/src/magicsock/relay_actor.rs b/iroh/src/magicsock/relay_actor.rs index d603724fe19..56ee46b50db 100644 --- a/iroh/src/magicsock/relay_actor.rs +++ b/iroh/src/magicsock/relay_actor.rs @@ -37,13 +37,12 @@ use std::{ }, }; -use anyhow::{anyhow, Context, Result}; use backon::{Backoff, BackoffBuilder, ExponentialBuilder}; use bytes::{Bytes, BytesMut}; use iroh_base::{NodeId, 
PublicKey, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, - client::{Client, ReceivedMessage, SendMessage}, + client::{Client, ConnectError, ReceivedMessage, RecvError, SendError, SendMessage}, PingTracker, MAX_PACKET_SIZE, }; use n0_future::{ @@ -51,6 +50,8 @@ use n0_future::{ time::{self, Duration, Instant, MissedTickBehavior}, FuturesUnorderedBounded, SinkExt, StreamExt, }; +use nested_enum_utils::common_fields; +use snafu::{IntoError, ResultExt, Snafu}; use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, event, info_span, instrument, trace, warn, Instrument, Level}; @@ -216,14 +217,61 @@ struct RelayConnectionOptions { } /// Possible reasons for a failed relay connection. -#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] enum RelayConnectionError { - #[error("Failed to connect to relay server: {0:#}")] - Connecting(#[source] anyhow::Error), - #[error("Failed to handshake with relay server: {0:#}")] - Handshake(#[source] anyhow::Error), - #[error("Lost connection to relay server: {0:#}")] - Established(#[source] anyhow::Error), + #[snafu(display("Failed to connect to relay server"))] + Dial { source: DialError }, + #[snafu(display("Failed to handshake with relay server"))] + Handshake { source: RunError }, + #[snafu(display("Lost connection to relay server"))] + Established { source: RunError }, +} + +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +enum RunError { + #[snafu(display("Send timeout"))] + SendTimeout {}, + #[snafu(display("Ping timeout"))] + PingTimeout {}, + #[snafu(display("Local IP no longer valid"))] + LocalIpInvalid {}, + #[snafu(display("No local address"))] + LocalAddrMissing {}, + #[snafu(display("Stream closed by server."))] + StreamClosedServer {}, 
+ #[snafu(display("Client stream read failed"))] + ClientStreamRead { source: RecvError }, + #[snafu(display("Client stream write failed"))] + ClientStreamWrite { source: SendError }, +} + +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +enum DialError { + #[snafu(display("timeout trying to establish a connection"))] + Timeout {}, + #[snafu(display("unable to connect"))] + Connect { + #[snafu(source(from(ConnectError, Box::new)))] + source: Box, + }, } impl ActiveRelayActor { @@ -287,7 +335,7 @@ impl ActiveRelayActor { /// The main actor run loop. /// /// Primarily switches between the dialing and connected states. - async fn run(mut self) -> Result<()> { + async fn run(mut self) { // TODO(frando): decide what this metric means, it's either wrong here or in node_state.rs. // From the existing description, it is wrong here. // self.metrics.num_relay_conns_added.inc(); @@ -297,14 +345,17 @@ impl ActiveRelayActor { while let Err(err) = self.run_once().await { warn!("{err}"); match err { - RelayConnectionError::Connecting(_) | RelayConnectionError::Handshake(_) => { + RelayConnectionError::Dial { .. } | RelayConnectionError::Handshake { .. } => { // If dialing failed, or if the relay connection failed before we received a pong, // we wait an exponentially increasing time until we attempt to reconnect again. - let delay = backoff.next().context("Retries exceeded")?; - debug!("Retry in {delay:?}"); + let Some(delay) = backoff.next() else { + warn!("retries exceeded"); + break; + }; + debug!("retry in {delay:?}"); time::sleep(delay).await; } - RelayConnectionError::Established(_) => { + RelayConnectionError::Established { .. } => { // If the relay connection remained established long enough so that we received a pong // from the relay server, we reset the backoff and attempt to reconnect immediately. 
backoff = Self::build_backoff(); @@ -315,7 +366,6 @@ impl ActiveRelayActor { // TODO(frando): decide what this metric means, it's either wrong here or in node_state.rs. // From the existing description, it is wrong here. // self.metrics.num_relay_conns_removed.inc(); - Ok(()) } fn build_backoff() -> impl Backoff { @@ -334,8 +384,7 @@ impl ActiveRelayActor { /// be retried with a backoff. async fn run_once(&mut self) -> Result<(), RelayConnectionError> { let client = match self.run_dialing().instrument(info_span!("dialing")).await { - Some(Ok(client)) => client, - Some(Err(err)) => return Err(RelayConnectionError::Connecting(err)), + Some(client_res) => client_res.context(DialSnafu)?, None => return Ok(()), }; self.run_connected(client) @@ -365,7 +414,7 @@ impl ActiveRelayActor { /// /// Returns `None` if the actor needs to shut down. Returns `Some(Ok(client))` when the /// connection is established, and `Some(Err(err))` if dialing the relay failed. - async fn run_dialing(&mut self) -> Option> { + async fn run_dialing(&mut self) -> Option> { debug!("Actor loop: connecting to relay."); // We regularly flush the relay_datagrams_send queue so it is not full of stale @@ -452,13 +501,13 @@ impl ActiveRelayActor { /// connections. It currently does not ever return `Err` as the retries continue /// forever. // This is using `impl Future` to return a future without a reference to self. 
- fn dial_relay(&self) -> impl Future> { + fn dial_relay(&self) -> impl Future> { let client_builder = self.relay_client_builder.clone(); async move { match time::timeout(CONNECT_TIMEOUT, client_builder.connect()).await { Ok(Ok(client)) => Ok(client), - Ok(Err(err)) => Err(err), - Err(_) => Err(anyhow!("Connecting timed out after {CONNECT_TIMEOUT:?}")), + Ok(Err(err)) => Err(ConnectSnafu.into_error(err)), + Err(_) => Err(TimeoutSnafu.build()), } } } @@ -479,7 +528,8 @@ impl ActiveRelayActor { home_relay = self.is_home_relay, ); - let (mut client_stream, mut client_sink) = client.split(); + let (mut client_stream, client_sink) = client.split(); + let mut client_sink = client_sink.sink_map_err(|e| ClientStreamWriteSnafu.into_error(e)); let mut state = ConnectedRelayState { ping_tracker: PingTracker::default(), @@ -524,7 +574,7 @@ impl ActiveRelayActor { } } _ = state.ping_tracker.timeout() => { - break Err(anyhow!("Ping timeout")); + break Err(PingTimeoutSnafu.build()); } _ = ping_interval.tick() => { let data = state.ping_tracker.new_ping(); @@ -547,8 +597,8 @@ impl ActiveRelayActor { let fut = client_sink.send(SendMessage::Ping(data)); self.run_sending(fut, &mut state, &mut client_stream).await?; } - Some(_) => break Err(anyhow!("Local IP no longer valid")), - None => break Err(anyhow!("No local addr, reconnecting")), + Some(_) => break Err(LocalIpInvalidSnafu.build()), + None => break Err(LocalAddrMissingSnafu.build()), } } #[cfg(test)] @@ -601,7 +651,7 @@ impl ActiveRelayActor { } msg = client_stream.next() => { let Some(msg) = msg else { - break Err(anyhow!("Stream closed by server.")); + break Err(StreamClosedServerSnafu.build()); }; match msg { Ok(msg) => { @@ -609,7 +659,7 @@ impl ActiveRelayActor { // reset the ping timer, we have just received a message ping_interval.reset(); }, - Err(err) => break Err(anyhow!("Client stream read error: {err:#}")), + Err(err) => break Err(ClientStreamReadSnafu.into_error(err)), } } _ = &mut self.inactive_timeout, if 
!self.is_home_relay => { @@ -695,9 +745,9 @@ impl ActiveRelayActor { /// the actor should shut down, consult the [`ActiveRelayActor::stop_token`] and /// [`ActiveRelayActor::inactive_timeout`] for this, or the send was successful. #[instrument(name = "tx", skip_all)] - async fn run_sending>( + async fn run_sending( &mut self, - sending_fut: impl Future>, + sending_fut: impl Future>, state: &mut ConnectedRelayState, client_stream: &mut iroh_relay::client::ClientStream, ) -> Result<(), RelayConnectionError> { @@ -713,7 +763,7 @@ impl ActiveRelayActor { break Ok(()); } _ = &mut timeout => { - break Err(anyhow!("Send timeout")); + break Err(SendTimeoutSnafu.build()); } msg = self.prio_inbox.recv() => { let Some(msg) = msg else { @@ -730,20 +780,20 @@ impl ActiveRelayActor { res = &mut sending_fut => { match res { Ok(_) => break Ok(()), - Err(err) => break Err(err.into()), + Err(err) => break Err(err), } } _ = state.ping_tracker.timeout() => { - break Err(anyhow!("Ping timeout")); + break Err(PingTimeoutSnafu.build()); } // No need to read the inbox or datagrams to send. 
msg = client_stream.next() => { let Some(msg) = msg else { - break Err(anyhow!("Stream closed by server.")); + break Err(StreamClosedServerSnafu.build()); }; match msg { Ok(msg) => self.handle_relay_msg(msg, state), - Err(err) => break Err(anyhow!("Client stream read error: {err:#}")), + Err(err) => break Err(ClientStreamReadSnafu.into_error(err)), } } _ = &mut self.inactive_timeout, if !self.is_home_relay => { @@ -782,11 +832,11 @@ struct ConnectedRelayState { } impl ConnectedRelayState { - fn map_err(&self, error: anyhow::Error) -> RelayConnectionError { + fn map_err(&self, error: RunError) -> RelayConnectionError { if self.established { - RelayConnectionError::Established(error) + EstablishedSnafu.into_error(error) } else { - RelayConnectionError::Handshake(error) + HandshakeSnafu.into_error(error) } } } @@ -1063,10 +1113,7 @@ impl RelayActor { let actor = ActiveRelayActor::new(opts); self.active_relay_tasks.spawn( async move { - // TODO: Make the actor itself infallible. - if let Err(err) = actor.run().await { - warn!("actor error: {err:#}"); - } + actor.run().await; } .instrument(span), ); @@ -1278,17 +1325,36 @@ impl Iterator for PacketSplitIter { #[cfg(test)] mod tests { - use anyhow::Context; - use iroh_base::SecretKey; + use std::{ + sync::{atomic::AtomicBool, Arc}, + time::Duration, + }; + + use bytes::Bytes; + use iroh_base::{NodeId, RelayUrl, SecretKey}; + use iroh_relay::PingTracker; use n0_future::future; + use n0_snafu::{Error, Result, ResultExt}; use smallvec::smallvec; - use testresult::TestResult; - use tokio_util::task::AbortOnDropHandle; - use tracing::info; + use tokio::sync::{mpsc, oneshot}; + use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; + use tracing::{info, info_span, Instrument}; use tracing_test::traced_test; - use super::*; - use crate::{dns::DnsResolver, test_utils}; + use super::{ + ActiveRelayActor, ActiveRelayActorOptions, ActiveRelayMessage, ActiveRelayPrioMessage, + RelayConnectionOptions, RelaySendItem, 
MAX_PACKET_SIZE, + }; + use crate::{ + dns::DnsResolver, + magicsock::{ + relay_actor::{ + PacketizeIter, RELAY_INACTIVE_CLEANUP_TIME, UNDELIVERABLE_DATAGRAM_TIMEOUT, + }, + RelayDatagramRecvQueue, RelayRecvDatagram, + }, + test_utils, + }; #[test] fn test_packetize_iter() { @@ -1332,7 +1398,7 @@ mod tests { relay_datagrams_send: mpsc::Receiver, relay_datagrams_recv: Arc, span: tracing::Span, - ) -> AbortOnDropHandle> { + ) -> AbortOnDropHandle<()> { let opts = ActiveRelayActorOptions { url, prio_inbox_: prio_inbox_rx, @@ -1425,7 +1491,7 @@ mod tests { tokio::time::timeout(Duration::from_secs(10), async move { loop { let res = tokio::time::timeout(UNDELIVERABLE_DATAGRAM_TIMEOUT, async { - tx.send(item.clone()).await?; + tx.send(item.clone()).await.context("send item")?; let RelayRecvDatagram { url: _, src: _, @@ -1434,7 +1500,7 @@ mod tests { assert_eq!(buf.as_ref(), item.datagrams[0]); - Ok::<_, anyhow::Error>(()) + Ok::<_, Error>(()) }) .await; if res.is_ok() { @@ -1449,7 +1515,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_active_relay_reconnect() -> TestResult { + async fn test_active_relay_reconnect() -> Result { let (_relay_map, relay_url, _server) = test_utils::run_relay_server().await?; let (peer_node, _echo_node_task) = start_echo_node(relay_url.clone()); @@ -1486,18 +1552,29 @@ mod tests { // Now ask to check the connection, triggering a ping but no reconnect. let (tx, rx) = oneshot::channel(); - inbox_tx.send(ActiveRelayMessage::GetLocalAddr(tx)).await?; - let local_addr = rx.await?.context("no local addr")?; + inbox_tx + .send(ActiveRelayMessage::GetLocalAddr(tx)) + .await + .context("send get local addr msg")?; + + let local_addr = rx + .await + .context("wait for local addr msg")? + .context("no local addr")?; info!(?local_addr, "check connection with addr"); inbox_tx .send(ActiveRelayMessage::CheckConnection(vec![local_addr.ip()])) - .await?; + .await + .context("send check connection message")?; // Sync the ActiveRelayActor. 
Ping blocks it and we want to be sure it has handled // another inbox message before continuing. let (tx, rx) = oneshot::channel(); - inbox_tx.send(ActiveRelayMessage::GetLocalAddr(tx)).await?; - rx.await?; + inbox_tx + .send(ActiveRelayMessage::GetLocalAddr(tx)) + .await + .context("send get local addr msg")?; + rx.await.context("recv send local addr msg")?; // Echo should still work. info!("second echo"); @@ -1513,7 +1590,8 @@ mod tests { info!("check connection"); inbox_tx .send(ActiveRelayMessage::CheckConnection(Vec::new())) - .await?; + .await + .context("send check connection msg")?; // Give some time to reconnect, mostly to sort logs rather than functional. tokio::time::sleep(Duration::from_millis(10)).await; @@ -1529,14 +1607,14 @@ mod tests { // Shut down the actor. cancel_token.cancel(); - task.await??; + task.await.context("wait for task to finish")?; Ok(()) } #[tokio::test] #[traced_test] - async fn test_active_relay_inactive() -> TestResult { + async fn test_active_relay_inactive() -> Result { let (_relay_map, relay_url, _server) = test_utils::run_relay_server().await?; let secret_key = SecretKey::from_bytes(&[1u8; 32]); @@ -1570,7 +1648,8 @@ mod tests { } } }) - .await?; + .await + .context("timeout")?; // We now have an idling ActiveRelayActor. If we advance time just a little it // should stay alive. 
diff --git a/iroh/src/net_report.rs b/iroh/src/net_report.rs index 2aa1d50a7af..0919fc1ea23 100644 --- a/iroh/src/net_report.rs +++ b/iroh/src/net_report.rs @@ -17,7 +17,6 @@ use std::{ sync::Arc, }; -use anyhow::{anyhow, Result}; use bytes::Bytes; use iroh_base::RelayUrl; #[cfg(not(wasm_browser))] @@ -27,8 +26,11 @@ use n0_future::{ task::{self, AbortOnDropHandle}, time::{Duration, Instant}, }; +use nested_enum_utils::common_fields; #[cfg(not(wasm_browser))] use netwatch::UdpSocket; +use reportgen::ActorRunError; +use snafu::Snafu; use tokio::sync::{self, mpsc, oneshot}; use tracing::{debug, error, info_span, trace, warn, Instrument}; @@ -248,7 +250,7 @@ impl Client { #[cfg(not(wasm_browser))] dns_resolver: DnsResolver, #[cfg(not(wasm_browser))] ip_mapped_addrs: Option, metrics: Arc, - ) -> Result { + ) -> Self { let mut actor = Actor::new( #[cfg(not(wasm_browser))] port_mapper, @@ -257,16 +259,16 @@ impl Client { #[cfg(not(wasm_browser))] ip_mapped_addrs, metrics, - )?; + ); let addr = actor.addr(); let task = task::spawn( async move { actor.run().await }.instrument(info_span!("net_report.actor")), ); let drop_guard = AbortOnDropHandle::new(task); - Ok(Client { + Client { addr, _drop_guard: Arc::new(drop_guard), - }) + } } /// Returns a new address to send messages to this actor. 
@@ -306,7 +308,7 @@ impl Client { #[cfg(not(wasm_browser))] stun_sock_v4: Option>, #[cfg(not(wasm_browser))] stun_sock_v6: Option>, #[cfg(not(wasm_browser))] quic_config: Option, - ) -> Result> { + ) -> Result, ReportError> { #[cfg(not(wasm_browser))] let opts = Options::default() .stun_v4(stun_sock_v4) @@ -314,10 +316,11 @@ impl Client { .quic_config(quic_config); #[cfg(wasm_browser)] let opts = Options::default(); - let rx = self.get_report_channel(relay_map.clone(), opts).await?; + + let rx = self.get_report_channel(relay_map, opts).await?; match rx.await { Ok(res) => res, - Err(_) => Err(anyhow!("channel closed, actor awol")), + Err(_) => Err(ActorGoneSnafu.build()), } } @@ -326,11 +329,15 @@ impl Client { /// It may not be called concurrently with itself, `&mut self` takes care of that. /// /// Look at [`Options`] for the different configuration options. - pub async fn get_report(&mut self, relay_map: RelayMap, opts: Options) -> Result> { + pub async fn get_report( + &mut self, + relay_map: RelayMap, + opts: Options, + ) -> Result, ReportError> { let rx = self.get_report_channel(relay_map, opts).await?; match rx.await { Ok(res) => res, - Err(_) => Err(anyhow!("channel closed, actor awol")), + Err(_) => Err(ActorGoneSnafu.build()), } } @@ -341,7 +348,7 @@ impl Client { &mut self, relay_map: RelayMap, opts: Options, - ) -> Result>>> { + ) -> Result, ReportError>>, ReportError> { let (tx, rx) = oneshot::channel(); self.addr .send(Message::RunCheck { @@ -349,7 +356,8 @@ impl Client { opts, response_tx: tx, }) - .await?; + .await + .map_err(|_| ActorGoneSnafu.build())?; Ok(rx) } } @@ -378,12 +386,12 @@ pub(crate) enum Message { /// Options for the report opts: Options, /// Channel to receive the response. - response_tx: oneshot::Sender>>, + response_tx: oneshot::Sender, ReportError>>, }, /// A report produced by the [`reportgen`] actor. ReportReady { report: Box }, /// The [`reportgen`] actor failed to produce a report. 
- ReportAborted { err: anyhow::Error }, + ReportAborted { reason: ActorRunError }, /// An incoming STUN packet to parse. StunPacket { /// The raw UDP payload. @@ -494,10 +502,10 @@ impl Actor { #[cfg(not(wasm_browser))] dns_resolver: DnsResolver, #[cfg(not(wasm_browser))] ip_mapped_addrs: Option, metrics: Arc, - ) -> Result { + ) -> Self { // TODO: consider an instrumented flume channel so we have metrics. let (sender, receiver) = mpsc::channel(32); - Ok(Self { + Self { receiver, sender, reports: Default::default(), @@ -510,7 +518,7 @@ impl Actor { #[cfg(not(wasm_browser))] ip_mapped_addrs, metrics, - }) + } } /// Returns the channel to send messages to the actor. @@ -540,7 +548,7 @@ impl Actor { Message::ReportReady { report } => { self.handle_report_ready(*report); } - Message::ReportAborted { err } => { + Message::ReportAborted { reason: err } => { self.handle_report_aborted(err); } Message::StunPacket { payload, from_addr } => { @@ -562,7 +570,7 @@ impl Actor { &mut self, relay_map: RelayMap, opts: Options, - response_tx: oneshot::Sender>>, + response_tx: oneshot::Sender, ReportError>>, ) { let protocols = opts.to_protocols(); #[cfg(not(wasm_browser))] @@ -576,11 +584,7 @@ impl Actor { }; trace!("Attempting probes for protocols {protocols:#?}"); if self.current_report_run.is_some() { - response_tx - .send(Err(anyhow!( - "ignoring RunCheck request: reportgen actor already running" - ))) - .ok(); + response_tx.send(Err(AlreadyRunningSnafu.build())).ok(); return; } @@ -629,10 +633,10 @@ impl Actor { } } - fn handle_report_aborted(&mut self, err: anyhow::Error) { + fn handle_report_aborted(&mut self, reason: ActorRunError) { self.in_flight_stun_requests.clear(); if let Some(ReportRun { report_tx, .. }) = self.current_report_run.take() { - report_tx.send(Err(err.context("report aborted"))).ok(); + report_tx.send(Err(AbortSnafu { reason }.build())).ok(); } } @@ -784,7 +788,24 @@ struct ReportRun { /// The handle of the [`reportgen`] actor, cancels the actor on drop. 
_reportgen: reportgen::Client, /// Where to send the completed report. - report_tx: oneshot::Sender>>, + report_tx: oneshot::Sender, ReportError>>, +} + +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ReportError { + #[snafu(display("Report aborted early"))] + Abort { reason: ActorRunError }, + #[snafu(display("Report generation is already running"))] + AlreadyRunning {}, + #[snafu(display("Internal actor is gone"))] + ActorGone {}, } /// Test if IPv6 works at all, or if it's been hard disabled at the OS level. @@ -801,7 +822,6 @@ fn os_has_ipv6() -> bool { #[cfg(not(wasm_browser))] pub(crate) mod stun_utils { - use anyhow::Context as _; use netwatch::IpFamily; use tokio_util::sync::CancellationToken; @@ -858,19 +878,32 @@ pub(crate) mod stun_utils { Some(sock) } + #[derive(Debug, Snafu)] + enum RecvStunError { + #[snafu(transparent)] + Recv { source: std::io::Error }, + #[snafu(display("Internal actor is gone"))] + ActorGone, + } + /// Receive STUN response from a UDP socket, pass it to the actor. 
- async fn recv_stun_once(sock: &UdpSocket, buf: &mut [u8], actor_addr: &Addr) -> Result<()> { - let (count, mut from_addr) = sock - .recv_from(buf) - .await - .context("Error reading from stun socket")?; + async fn recv_stun_once( + sock: &UdpSocket, + buf: &mut [u8], + actor_addr: &Addr, + ) -> Result<(), RecvStunError> { + let (count, mut from_addr) = sock.recv_from(buf).await?; + let payload = &buf[..count]; from_addr.set_ip(from_addr.ip().to_canonical()); let msg = Message::StunPacket { payload: Bytes::from(payload.to_vec()), from_addr, }; - actor_addr.send(msg).await.context("actor stopped") + actor_addr + .send(msg) + .await + .map_err(|_| ActorGoneSnafu.build()) } } @@ -915,23 +948,21 @@ mod test_utils { #[cfg(test)] mod tests { - use std::net::Ipv4Addr; - use bytes::BytesMut; + use n0_snafu::{Result, ResultExt}; use netwatch::IpFamily; use tokio_util::sync::CancellationToken; use tracing::info; use tracing_test::traced_test; use super::*; - use crate::net_report::{dns, ping::Pinger, stun_utils::bind_local_stun_socket}; + use crate::net_report::{dns, stun_utils::bind_local_stun_socket}; mod stun_utils { //! Utils for testing that expose a simple stun server. use std::{net::IpAddr, sync::Arc}; - use anyhow::Result; use iroh_base::RelayUrl; use iroh_relay::RelayNode; use tokio::{ @@ -987,12 +1018,15 @@ mod tests { /// Sets up a simple STUN server binding to `0.0.0.0:0`. /// /// See [`serve`] for more details. - pub(crate) async fn serve_v4() -> Result<(SocketAddr, StunStats, CleanupDropGuard)> { + pub(crate) async fn serve_v4() -> std::io::Result<(SocketAddr, StunStats, CleanupDropGuard)> + { serve(std::net::Ipv4Addr::UNSPECIFIED.into()).await } /// Sets up a simple STUN server. 
- pub(crate) async fn serve(ip: IpAddr) -> Result<(SocketAddr, StunStats, CleanupDropGuard)> { + pub(crate) async fn serve( + ip: IpAddr, + ) -> std::io::Result<(SocketAddr, StunStats, CleanupDropGuard)> { let stats = StunStats::default(); let pc = net::UdpSocket::bind((ip, 0)).await?; @@ -1060,12 +1094,12 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_basic() -> Result<()> { + async fn test_basic() -> Result { let (stun_addr, stun_stats, _cleanup_guard) = - stun_utils::serve("127.0.0.1".parse().unwrap()).await?; + stun_utils::serve("127.0.0.1".parse().unwrap()).await.e()?; let resolver = dns::tests::resolver(); - let mut client = Client::new(None, resolver.clone(), None, Default::default())?; + let mut client = Client::new(None, resolver.clone(), None, Default::default()); let dm = stun_utils::relay_map_of([stun_addr].into_iter()); // Note that the ProbePlan will change with each iteration. @@ -1101,60 +1135,8 @@ mod tests { Ok(()) } - #[tokio::test] - #[traced_test] - async fn test_udp_blocked() -> Result<()> { - // Create a "STUN server", which will never respond to anything. This is how UDP to - // the STUN server being blocked will look like from the client's perspective. - let blackhole = tokio::net::UdpSocket::bind("127.0.0.1:0").await?; - let stun_addr = blackhole.local_addr()?; - let dm = stun_utils::relay_map_of_opts([(stun_addr, false)].into_iter()); - - // Now create a client and generate a report. - let resolver = dns::tests::resolver(); - let mut client = Client::new(None, resolver.clone(), None, Default::default())?; - - let r = client.get_report_all(dm, None, None, None).await?; - let mut r: Report = (*r).clone(); - r.portmap_probe = None; - - // This test wants to ensure that the ICMP part of the probe works when UDP is - // blocked. Unfortunately on some systems we simply don't have permissions to - // create raw ICMP pings and we'll have to silently accept this test is useless (if - // we could, this would be a skip instead). 
- let pinger = Pinger::new(); - let can_ping = pinger.send(Ipv4Addr::LOCALHOST.into(), b"aa").await.is_ok(); - let want_icmpv4 = match can_ping { - true => Some(true), - false => None, - }; - - let want = Report { - // The ICMP probe sets the can_ping flag. - ipv4_can_send: can_ping, - // OS IPv6 test is irrelevant here, accept whatever the current machine has. - os_has_ipv6: r.os_has_ipv6, - // Captive portal test is irrelevant; accept what the current report has. - captive_portal: r.captive_portal, - // If we can ping we expect to have this. - icmpv4: want_icmpv4, - // If we had a pinger, we'll have some latencies filled in and a preferred relay - relay_latency: can_ping - .then(|| r.relay_latency.clone()) - .unwrap_or_default(), - preferred_relay: can_ping - .then_some(r.preferred_relay.clone()) - .unwrap_or_default(), - ..Default::default() - }; - - assert_eq!(r, want); - - Ok(()) - } - #[tokio::test(flavor = "current_thread", start_paused = true)] - async fn test_add_report_history_set_preferred_relay() -> Result<()> { + async fn test_add_report_history_set_preferred_relay() -> Result { fn relay_url(i: u16) -> RelayUrl { format!("http://{i}.com").parse().unwrap() } @@ -1315,7 +1297,7 @@ mod tests { let resolver = dns::tests::resolver(); for mut tt in tests { println!("test: {}", tt.name); - let mut actor = Actor::new(None, resolver.clone(), None, Default::default()).unwrap(); + let mut actor = Actor::new(None, resolver.clone(), None, Default::default()); for s in &mut tt.steps { // trigger the timer tokio::time::advance(Duration::from_secs(s.after)).await; @@ -1335,7 +1317,7 @@ mod tests { } #[tokio::test] - async fn test_hairpin() -> Result<()> { + async fn test_hairpin() -> Result { // Hairpinning is initiated after we discover our own IPv4 socket address (IP + // port) via STUN, so the test needs to have a STUN server and perform STUN over // IPv4 first. 
Hairpinning detection works by sending a STUN *request* to **our own @@ -1345,12 +1327,12 @@ mod tests { // can easily use to identify the packet. // Setup STUN server and create relay_map. - let (stun_addr, _stun_stats, _done) = stun_utils::serve_v4().await?; + let (stun_addr, _stun_stats, _done) = stun_utils::serve_v4().await.e()?; let dm = stun_utils::relay_map_of([stun_addr].into_iter()); dbg!(&dm); let resolver = dns::tests::resolver().clone(); - let mut client = Client::new(None, resolver, None, Default::default())?; + let mut client = Client::new(None, resolver, None, Default::default()); // Set up an external socket to send STUN requests from, this will be discovered as // our public socket address by STUN. We send back any packets received on this @@ -1359,7 +1341,7 @@ mod tests { // it to this socket, which is forwarnding it back to our net_report client, because // this dumb implementation just forwards anything even if it would be garbage. // Thus hairpinning detection will declare hairpinning to work. - let sock = UdpSocket::bind_local(netwatch::IpFamily::V4, 0)?; + let sock = UdpSocket::bind_local(netwatch::IpFamily::V4, 0).e()?; let sock = Arc::new(sock); info!(addr=?sock.local_addr().unwrap(), "Using local addr"); let task = { diff --git a/iroh/src/net_report/ip_mapped_addrs.rs b/iroh/src/net_report/ip_mapped_addrs.rs index 42a07f501bc..8a16db7a921 100644 --- a/iroh/src/net_report/ip_mapped_addrs.rs +++ b/iroh/src/net_report/ip_mapped_addrs.rs @@ -7,9 +7,11 @@ use std::{ }, }; +use snafu::Snafu; + /// Can occur when converting a [`SocketAddr`] to an [`IpMappedAddr`] -#[derive(Debug, thiserror::Error)] -#[error("Failed to convert")] +#[derive(Debug, Snafu)] +#[snafu(display("Failed to convert"))] pub struct IpMappedAddrError; /// A map fake Ipv6 address with an actual IP address. 
diff --git a/iroh/src/net_report/ping.rs b/iroh/src/net_report/ping.rs index 6ec3dff73c8..c887ea1e99c 100644 --- a/iroh/src/net_report/ping.rs +++ b/iroh/src/net_report/ping.rs @@ -6,22 +6,31 @@ use std::{ sync::{Arc, Mutex}, }; -use anyhow::{Context, Result}; use n0_future::time::Duration; +use nested_enum_utils::common_fields; +use snafu::{ResultExt, Snafu}; use surge_ping::{Client, Config, IcmpPacket, PingIdentifier, PingSequence, ICMP}; use tracing::debug; use crate::net_report::defaults::timeouts::DEFAULT_PINGER_TIMEOUT as DEFAULT_TIMEOUT; /// Whether this error was because we couldn't create a client or a send error. -#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub enum PingError { - /// Could not create client, probably bind error. - #[error("Error creating ping client")] - Client(#[from] anyhow::Error), + #[snafu(display("failed to create IPv4 ping client"))] + CreateClientIpv4 { source: std::io::Error }, + #[snafu(display("failed to create IPv6 ping client"))] + CreateClientIpv6 { source: std::io::Error }, /// Could not send ping. - #[error("Error sending ping")] - Ping(#[from] surge_ping::SurgeError), + #[snafu(display("failed to send ping"))] + Ping { source: surge_ping::SurgeError }, } /// Allows sending ICMP echo requests to a host in order to determine network latency. @@ -51,7 +60,7 @@ impl Pinger { /// /// We do this because it means we do not bind a socket until we really try to send a /// ping. It makes it more transparent to use the pinger. 
- fn get_client(&self, kind: ICMP) -> Result { + fn get_client(&self, kind: ICMP) -> Result { let client = match kind { ICMP::V4 => { let mut opt_client = self.0.client_v4.lock().expect("poisoned"); @@ -59,7 +68,7 @@ impl Pinger { Some(ref client) => client.clone(), None => { let cfg = Config::builder().kind(kind).build(); - let client = Client::new(&cfg).context("failed to create IPv4 pinger")?; + let client = Client::new(&cfg).context(CreateClientIpv4Snafu)?; *opt_client = Some(client.clone()); client } @@ -71,7 +80,7 @@ impl Pinger { Some(ref client) => client.clone(), None => { let cfg = Config::builder().kind(kind).build(); - let client = Client::new(&cfg).context("failed to create IPv6 pinger")?; + let client = Client::new(&cfg).context(CreateClientIpv6Snafu)?; *opt_client = Some(client.clone()); client } @@ -84,14 +93,18 @@ impl Pinger { /// Send a ping request with associated data, returning the perceived latency. pub async fn send(&self, addr: IpAddr, data: &[u8]) -> Result { let client = match addr { - IpAddr::V4(_) => self.get_client(ICMP::V4).map_err(PingError::Client)?, - IpAddr::V6(_) => self.get_client(ICMP::V6).map_err(PingError::Client)?, + IpAddr::V4(_) => self.get_client(ICMP::V4)?, + IpAddr::V6(_) => self.get_client(ICMP::V6)?, }; let ident = PingIdentifier(rand::random()); debug!(%addr, %ident, "Creating pinger"); let mut pinger = client.pinger(addr, ident).await; pinger.timeout(DEFAULT_TIMEOUT); // todo: timeout too large for net_report - match pinger.ping(PingSequence(0), data).await? { + match pinger + .ping(PingSequence(0), data) + .await + .context(PingSnafu)? + { (IcmpPacket::V4(packet), dur) => { debug!( "{} bytes from {}: icmp_seq={} ttl={:?} time={:0.2?}", @@ -138,12 +151,15 @@ mod tests { Ok(duration) => { assert!(!duration.is_zero()); } - Err(PingError::Client(err)) => { + Err( + PingError::CreateClientIpv4 { source, .. } + | PingError::CreateClientIpv6 { source, .. }, + ) => { // We don't have permission, too bad. 
- error!("no ping permissions: {err:#}"); + error!("no ping permissions: {source:#}"); } - Err(PingError::Ping(err)) => { - panic!("ping failed: {err:#}"); + Err(PingError::Ping { source, .. }) => { + panic!("ping failed: {source:#}"); } } @@ -151,12 +167,15 @@ mod tests { Ok(duration) => { assert!(!duration.is_zero()); } - Err(PingError::Client(err)) => { + Err( + PingError::CreateClientIpv4 { source, .. } + | PingError::CreateClientIpv6 { source, .. }, + ) => { // We don't have permission, too bad. - error!("no ping permissions: {err:#}"); + error!("no ping permissions: {source:#}"); } - Err(PingError::Ping(err)) => { - error!("ping failed, probably no IPv6 stack: {err:#}"); + Err(PingError::Ping { source, .. }) => { + error!("ping failed, probably no IPv6 stack: {source:#}"); } } } diff --git a/iroh/src/net_report/reportgen.rs b/iroh/src/net_report/reportgen.rs index ba27d9ea887..840b20f8ffb 100644 --- a/iroh/src/net_report/reportgen.rs +++ b/iroh/src/net_report/reportgen.rs @@ -25,10 +25,10 @@ use std::{ task::{Context, Poll}, }; -use anyhow::{anyhow, bail, Context as _, Result}; +use http::StatusCode; use iroh_base::RelayUrl; #[cfg(not(wasm_browser))] -use iroh_relay::dns::DnsResolver; +use iroh_relay::dns::{DnsError, DnsResolver}; use iroh_relay::{ defaults::{DEFAULT_RELAY_QUIC_PORT, DEFAULT_STUN_PORT}, http::RELAY_PROBE_PATH, @@ -45,6 +45,7 @@ use n0_future::{ #[cfg(not(wasm_browser))] use netwatch::{interfaces, UdpSocket}; use rand::seq::IteratorRandom; +use snafu::{IntoError, ResultExt, Snafu}; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, debug_span, error, info_span, trace, warn, Instrument, Span}; use url::Host; @@ -216,6 +217,30 @@ struct Actor { metrics: Arc, } +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ActorRunError { + #[snafu(display("Report generation timed out"))] + Timeout, + #[snafu(display("Client that requested the report is gone"))] + ClientGone, + #[snafu(display("Internal NetReport actor is 
gone"))] + ActorGone, + #[snafu(transparent)] + Probes { source: ProbesError }, +} + +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ProbesError { + #[snafu(display("Probe failed"))] + ProbeFailure { source: ProbeError }, + #[snafu(display("All probes failed"))] + AllProbesFailed, +} + impl Actor { fn addr(&self) -> Addr { Addr { @@ -228,7 +253,7 @@ impl Actor { Ok(_) => debug!("reportgen actor finished"), Err(err) => { self.net_report - .send(net_report::Message::ReportAborted { err }) + .send(net_report::Message::ReportAborted { reason: err }) .await .ok(); } @@ -247,7 +272,7 @@ impl Actor { /// - Receives actor messages (sent by those futures). /// - Updates the report, cancels unneeded futures. /// - Sends the report to the net_report actor. - async fn run_inner(&mut self) -> Result<()> { + async fn run_inner(&mut self) -> Result<(), ActorRunError> { #[cfg(not(wasm_browser))] let port_mapper = self.socket_state.port_mapper.is_some(); #[cfg(wasm_browser)] @@ -258,7 +283,7 @@ impl Actor { let mut port_mapping = self.prepare_portmapper_task(); let mut captive_task = self.prepare_captive_portal_task(); - let mut probes = self.spawn_probes_task().await?; + let mut probes = self.spawn_probes_task().await; let total_timer = time::sleep(OVERALL_REPORT_TIMEOUT); tokio::pin!(total_timer); @@ -275,7 +300,7 @@ impl Actor { biased; _ = &mut total_timer => { trace!("tick: total_timer expired"); - bail!("report timed out"); + return Err(TimeoutSnafu.build()); } _ = &mut probe_timer => { @@ -325,7 +350,9 @@ impl Actor { trace!("tick: msg recv: {:?}", msg); match msg { Some(msg) => self.handle_message(msg), - None => bail!("msg_rx closed, reportgen client must be dropped"), + None => { + return Err(ClientGoneSnafu.build()); + } } } } @@ -344,7 +371,8 @@ impl Actor { .send(net_report::Message::ReportReady { report: Box::new(self.report.clone()), }) - .await?; + .await + .map_err(|_| ActorGoneSnafu.build())?; Ok(()) } @@ -531,13 +559,14 @@ impl Actor 
{ match captive_portal_check.await { Ok(Ok(found)) => Some(found), Ok(Err(err)) => { - let err: Result = err.downcast(); match err { - Ok(req_err) if req_err.is_connect() => { - debug!("check_captive_portal failed: {req_err:#}"); + CaptivePortalError::CreateReqwestClient { ref source } + | CaptivePortalError::HttpRequest { ref source } => { + if source.is_connect() { + debug!("check_captive_portal failed: {err:#}"); + } } - Ok(req_err) => warn!("check_captive_portal error: {req_err:#}"), - Err(any_err) => warn!("check_captive_portal error: {any_err:#}"), + _ => warn!("check_captive_portal error: {err:#}"), } None } @@ -571,7 +600,7 @@ impl Actor { /// failure permanent. Probes in a probe set are essentially retries. /// - Once there are [`ProbeReport`]s from enough nodes, all remaining probes are /// aborted. That is, the main actor loop stops polling them. - async fn spawn_probes_task(&mut self) -> Result>> { + async fn spawn_probes_task(&mut self) -> JoinSet> { #[cfg(not(wasm_browser))] let if_state = interfaces::State::new().await; #[cfg(not(wasm_browser))] @@ -640,15 +669,15 @@ impl Actor { while let Some(res) = set.join_next().await { match res { Ok(Ok(report)) => return Ok(report), - Ok(Err(ProbeError::Error(err, probe))) => { + Ok(Err(ProbeErrorWithProbe::Error(err, probe))) => { probe_proto = Some(probe.proto()); warn!(?probe, "probe failed: {:#}", err); continue; } - Ok(Err(ProbeError::AbortSet(err, probe))) => { + Ok(Err(ProbeErrorWithProbe::AbortSet(err, probe))) => { debug!(?probe, "probe set aborted: {:#}", err); set.abort_all(); - return Err(err); + return Err(ProbeFailureSnafu.into_error(err)); } Err(err) => { warn!("fatal probe set error, aborting: {:#}", err); @@ -657,14 +686,14 @@ impl Actor { } } warn!(?probe_proto, "no successful probes in ProbeSet"); - Err(anyhow!("All probes in ProbeSet failed")) + Err(AllProbesFailedSnafu.build()) } .instrument(info_span!("probe")), ); } self.outstanding_tasks.probes = true; - Ok(probes) + probes } } @@ 
-727,11 +756,62 @@ impl ProbeReport { /// the others in the set. So this allows an unsuccessful probe to cancel the remainder of /// the set or not. #[derive(Debug)] -enum ProbeError { +enum ProbeErrorWithProbe { /// Abort the current set. - AbortSet(anyhow::Error, Probe), + AbortSet(ProbeError, Probe), /// Continue the other probes in the set. - Error(anyhow::Error, Probe), + Error(ProbeError, Probe), +} + +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[snafu(module)] +#[non_exhaustive] +pub enum ProbeError { + #[snafu(display("Client is gone"))] + ClientGone, + #[snafu(display("Probe is no longer useful"))] + NotUseful, + #[cfg(not(wasm_browser))] + #[snafu(display("Failed to retrieve the relay address"))] + GetRelayAddr { source: GetRelayAddrError }, + #[snafu(display("Failed to run stun probe"))] + Stun { source: StunError }, + #[snafu(display("Failed to run QUIC probe"))] + Quic { source: QuicError }, + #[cfg(not(wasm_browser))] + #[snafu(display("Failed to run ICMP probe"))] + Icmp { source: PingError }, +} + +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[snafu(module)] +#[non_exhaustive] +pub enum StunError { + #[snafu(display("No UDP socket available"))] + NoSocket, + #[snafu(display("Stun channel is gone"))] + StunChannelGone, + #[snafu(display("Failed to send full STUN request"))] + SendFull, + #[snafu(display("Failed to send STUN request"))] + Send { source: std::io::Error }, +} + +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[snafu(module)] +#[non_exhaustive] +pub enum QuicError { + #[snafu(display("No QUIC endpoint available"))] + NoEndpoint, + #[snafu(display("URL must have 'host' to use QUIC address discovery probes"))] + InvalidUrl, + #[snafu(display("Failed to create QUIC endpoint"))] + CreateClient { source: iroh_relay::quic::Error }, + #[snafu(display("Failed to get address and latency"))] + GetAddr { source: iroh_relay::quic::Error }, } /// Pieces needed to do QUIC address discovery. 
@@ -758,7 +838,7 @@ async fn run_probe( metrics: Arc, #[cfg(not(wasm_browser))] pinger: Pinger, #[cfg(not(wasm_browser))] socket_state: SocketState, -) -> Result { +) -> Result { if !probe.delay().is_zero() { trace!("delaying probe"); time::sleep(probe.delay()).await; @@ -776,17 +856,17 @@ async fn run_probe( { // this happens on shutdown or if the report is already finished debug!("Failed to check if probe would help: {err:#}"); - return Err(ProbeError::AbortSet(err.into(), probe.clone())); + return Err(ProbeErrorWithProbe::AbortSet( + probe_error::ClientGoneSnafu.build(), + probe.clone(), + )); } if !would_help_rx.await.map_err(|_| { - ProbeError::AbortSet( - anyhow!("ReportCheck actor dropped sender while waiting for ProbeWouldHelp response"), - probe.clone(), - ) + ProbeErrorWithProbe::AbortSet(probe_error::ClientGoneSnafu.build(), probe.clone()) })? { - return Err(ProbeError::AbortSet( - anyhow!("ReportCheck says probe set no longer useful"), + return Err(ProbeErrorWithProbe::AbortSet( + probe_error::NotUsefulSnafu.build(), probe, )); } @@ -794,8 +874,12 @@ async fn run_probe( #[cfg(not(wasm_browser))] let relay_addr = get_relay_addr(&socket_state.dns_resolver, &relay_node, probe.proto()) .await - .context("no relay node addr") - .map_err(|e| ProbeError::AbortSet(e, probe.clone()))?; + .map_err(|e| { + ProbeErrorWithProbe::AbortSet( + probe_error::GetRelayAddrSnafu.into_error(e), + probe.clone(), + ) + })?; let mut result = ProbeReport::new(probe.clone()); match probe { @@ -811,8 +895,8 @@ async fn run_probe( result = run_stun_probe(sock, relay_addr, net_report, probe, &metrics).await?; } None => { - return Err(ProbeError::AbortSet( - anyhow!("No socket for {}, aborting probeset", probe.proto()), + return Err(ProbeErrorWithProbe::AbortSet( + probe_error::StunSnafu.into_error(stun_error::NoSocketSnafu.build()), probe.clone(), )); } @@ -867,8 +951,8 @@ async fn run_probe( .await?; } None => { - return Err(ProbeError::AbortSet( - anyhow!("No QUIC endpoint for 
{}", probe.proto()), + return Err(ProbeErrorWithProbe::AbortSet( + probe_error::QuicSnafu.into_error(quic_error::NoEndpointSnafu.build()), probe.clone(), )); } @@ -888,7 +972,7 @@ async fn run_stun_probe( net_report: net_report::Addr, probe: Probe, metrics: &Metrics, -) -> Result { +) -> Result { match probe.proto() { ProbeProto::StunIpv4 => debug_assert!(relay_addr.is_ipv4()), ProbeProto::StunIpv6 => debug_assert!(relay_addr.is_ipv6()), @@ -910,10 +994,12 @@ async fn run_stun_probe( inflight_ready_tx, )) .await - .map_err(|e| ProbeError::Error(e.into(), probe.clone()))?; - inflight_ready_rx - .await - .map_err(|e| ProbeError::Error(e.into(), probe.clone()))?; + .map_err(|_| { + ProbeErrorWithProbe::Error(probe_error::ClientGoneSnafu.build(), probe.clone()) + })?; + inflight_ready_rx.await.map_err(|_| { + ProbeErrorWithProbe::Error(probe_error::ClientGoneSnafu.build(), probe.clone()) + })?; // Send the probe. match sock.send_to(&req, relay_addr).await { @@ -928,22 +1014,27 @@ async fn run_stun_probe( result.ipv6_can_send = true; metrics.stun_packets_sent_ipv6.inc(); } - let (delay, addr) = stun_rx - .await - .map_err(|e| ProbeError::Error(e.into(), probe.clone()))?; + let (delay, addr) = stun_rx.await.map_err(|_| { + ProbeErrorWithProbe::Error( + probe_error::StunSnafu.into_error(stun_error::StunChannelGoneSnafu.build()), + probe.clone(), + ) + })?; result.latency = Some(delay); result.addr = Some(addr); Ok(result) } Ok(n) => { - let err = anyhow!("Failed to send full STUN request: {}", probe.proto()); + let err = stun_error::SendFullSnafu.build(); error!(%relay_addr, sent_len=n, req_len=req.len(), "{err:#}"); - Err(ProbeError::Error(err, probe.clone())) + Err(ProbeErrorWithProbe::Error( + probe_error::StunSnafu.into_error(err), + probe.clone(), + )) } Err(err) => { let kind = err.kind(); - let err = anyhow::Error::new(err) - .context(format!("Failed to send STUN request: {}", probe.proto())); + let err = stun_error::SendSnafu.into_error(err); // It is entirely 
normal that we are on a dual-stack machine with no // routed IPv6 network. So silence that case. @@ -952,11 +1043,17 @@ async fn run_stun_probe( match format!("{kind:?}").as_str() { "NetworkUnreachable" | "HostUnreachable" => { debug!(%relay_addr, "{err:#}"); - Err(ProbeError::AbortSet(err, probe.clone())) + Err(ProbeErrorWithProbe::AbortSet( + probe_error::StunSnafu.into_error(err), + probe.clone(), + )) } _ => { // No need to log this, our caller does already log this. - Err(ProbeError::Error(err, probe.clone())) + Err(ProbeErrorWithProbe::Error( + probe_error::StunSnafu.into_error(err), + probe.clone(), + )) } } } @@ -982,7 +1079,7 @@ async fn run_quic_probe( relay_addr: SocketAddr, probe: Probe, ip_mapped_addrs: Option, -) -> Result { +) -> Result { match probe.proto() { ProbeProto::QuicIpv4 => debug_assert!(relay_addr.is_ipv4()), ProbeProto::QuicIpv6 => debug_assert!(relay_addr.is_ipv6()), @@ -992,18 +1089,28 @@ async fn run_quic_probe( let host = match url.host_str() { Some(host) => host, None => { - return Err(ProbeError::Error( - anyhow!("URL must have 'host' to use QUIC address discovery probes"), + return Err(ProbeErrorWithProbe::Error( + probe_error::QuicSnafu.into_error(quic_error::InvalidUrlSnafu.build()), probe.clone(), )); } }; let quic_client = iroh_relay::quic::QuicClient::new(quic_config.ep, quic_config.client_config) - .map_err(|e| ProbeError::Error(e, probe.clone()))?; + .map_err(|e| { + ProbeErrorWithProbe::Error( + probe_error::QuicSnafu.into_error(quic_error::CreateClientSnafu.into_error(e)), + probe.clone(), + ) + })?; let (addr, latency) = quic_client .get_addr_and_latency(relay_addr, host) .await - .map_err(|e| ProbeError::Error(e, probe.clone()))?; + .map_err(|e| { + ProbeErrorWithProbe::Error( + probe_error::QuicSnafu.into_error(quic_error::GetAddrSnafu.into_error(e)), + probe.clone(), + ) + })?; let mut result = ProbeReport::new(probe.clone()); if matches!(probe, Probe::QuicIpv4 { .. 
}) { result.ipv4_can_send = true; @@ -1015,6 +1122,19 @@ async fn run_quic_probe( Ok(result) } +#[cfg(not(wasm_browser))] +#[derive(Debug, Snafu)] +#[snafu(module)] +#[non_exhaustive] +enum CaptivePortalError { + #[snafu(transparent)] + DnsLookup { source: DnsError }, + #[snafu(display("Creating HTTP client failed"))] + CreateReqwestClient { source: reqwest::Error }, + #[snafu(display("HTTP request failed"))] + HttpRequest { source: reqwest::Error }, +} + /// Reports whether or not we think the system is behind a /// captive portal, detected by making a request to a URL that we know should /// return a "204 No Content" response and checking if that's what we get. @@ -1025,7 +1145,7 @@ async fn check_captive_portal( dns_resolver: &DnsResolver, dm: &RelayMap, preferred_relay: Option, -) -> Result { +) -> Result { // If we have a preferred relay node and we can use it for non-STUN requests, try that; // otherwise, pick a random one suitable for non-STUN requests. let preferred_relay = preferred_relay.and_then(|url| match dm.get_node(&url) { @@ -1069,7 +1189,9 @@ async fn check_captive_portal( .collect(); builder = builder.resolve_to_addrs(domain, &addrs); } - let client = builder.build()?; + let client = builder + .build() + .context(captive_portal_error::CreateReqwestClientSnafu)?; // Note: the set of valid characters in a challenge and the total // length is limited; see is_challenge_char in bin/iroh-relay for more @@ -1082,7 +1204,8 @@ async fn check_captive_portal( .request(reqwest::Method::GET, portal_url) .header("X-Tailscale-Challenge", &challenge) .send() - .await?; + .await + .context(captive_portal_error::HttpRequestSnafu)?; let expected_response = format!("response {challenge}"); let is_valid_response = res @@ -1103,30 +1226,49 @@ async fn check_captive_portal( } /// Returns the proper port based on the protocol of the probe. 
-fn get_port(relay_node: &RelayNode, proto: &ProbeProto) -> Result { +fn get_port(relay_node: &RelayNode, proto: &ProbeProto) -> Option { match proto { #[cfg(not(wasm_browser))] ProbeProto::QuicIpv4 | ProbeProto::QuicIpv6 => { if let Some(ref quic) = relay_node.quic { if quic.port == 0 { - Ok(DEFAULT_RELAY_QUIC_PORT) + Some(DEFAULT_RELAY_QUIC_PORT) } else { - Ok(quic.port) + Some(quic.port) } } else { - bail!("Relay node not suitable for QUIC address discovery probes"); + None } } _ => { if relay_node.stun_port == 0 { - Ok(DEFAULT_STUN_PORT) + Some(DEFAULT_STUN_PORT) } else { - Ok(relay_node.stun_port) + Some(relay_node.stun_port) } } } } +#[cfg(not(wasm_browser))] +#[derive(Debug, Snafu)] +#[snafu(module)] +#[non_exhaustive] +pub enum GetRelayAddrError { + #[snafu(display("No valid hostname in the relay URL"))] + InvalidHostname, + #[snafu(display("No suitable relay address found"))] + NoAddrFound, + #[snafu(display("DNS lookup failed"))] + DnsLookup { source: DnsError }, + #[snafu(display("Relay node is not suitable for non-STUN probes"))] + UnsupportedRelayNode, + #[snafu(display("HTTPS probes are not implemented"))] + UnsupportedHttps, + #[snafu(display("No port available for this protocol"))] + MissingPort, +} + /// Returns the IP address to use to communicate to this relay node. /// /// *proto* specifies the protocol of the probe. 
Depending on the protocol we may return @@ -1142,11 +1284,13 @@ async fn get_relay_addr( dns_resolver: &DnsResolver, relay_node: &RelayNode, proto: ProbeProto, -) -> Result { +) -> Result { + use snafu::OptionExt; + if relay_node.stun_only && !matches!(proto, ProbeProto::StunIpv4 | ProbeProto::StunIpv6) { - bail!("Relay node not suitable for non-STUN probes"); + return Err(get_relay_addr_error::UnsupportedRelayNodeSnafu.build()); } - let port = get_port(relay_node, &proto)?; + let port = get_port(relay_node, &proto).context(get_relay_addr_error::MissingPortSnafu)?; match proto { ProbeProto::StunIpv4 | ProbeProto::IcmpV4 | ProbeProto::QuicIpv4 => { @@ -1157,7 +1301,7 @@ async fn get_relay_addr( relay_lookup_ipv6_staggered(dns_resolver, relay_node, port).await } - ProbeProto::Https => Err(anyhow!("Not implemented")), + ProbeProto::Https => Err(get_relay_addr_error::UnsupportedHttpsSnafu.build()), } } @@ -1169,7 +1313,7 @@ async fn relay_lookup_ipv4_staggered( dns_resolver: &DnsResolver, relay: &RelayNode, port: u16, -) -> Result { +) -> Result { match relay.url.host() { Some(url::Host::Domain(hostname)) => { debug!(%hostname, "Performing DNS A lookup for relay addr"); @@ -1181,13 +1325,13 @@ async fn relay_lookup_ipv4_staggered( .next() .map(|ip| ip.to_canonical()) .map(|addr| SocketAddr::new(addr, port)) - .ok_or(anyhow!("No suitable relay addr found")), - Err(err) => Err(err.context("No suitable relay addr found")), + .ok_or(get_relay_addr_error::NoAddrFoundSnafu.build()), + Err(err) => Err(get_relay_addr_error::DnsLookupSnafu.into_error(err)), } } Some(url::Host::Ipv4(addr)) => Ok(SocketAddr::new(addr.into(), port)), - Some(url::Host::Ipv6(_addr)) => Err(anyhow!("No suitable relay addr found")), - None => Err(anyhow!("No valid hostname in RelayUrl")), + Some(url::Host::Ipv6(_addr)) => Err(get_relay_addr_error::NoAddrFoundSnafu.build()), + None => Err(get_relay_addr_error::InvalidHostnameSnafu.build()), } } @@ -1199,7 +1343,7 @@ async fn 
relay_lookup_ipv6_staggered( dns_resolver: &DnsResolver, relay: &RelayNode, port: u16, -) -> Result { +) -> Result { match relay.url.host() { Some(url::Host::Domain(hostname)) => { debug!(%hostname, "Performing DNS AAAA lookup for relay addr"); @@ -1211,13 +1355,13 @@ async fn relay_lookup_ipv6_staggered( .next() .map(|ip| ip.to_canonical()) .map(|addr| SocketAddr::new(addr, port)) - .ok_or(anyhow!("No suitable relay addr found")), - Err(err) => Err(err.context("No suitable relay addr found")), + .ok_or(get_relay_addr_error::NoAddrFoundSnafu.build()), + Err(err) => Err(get_relay_addr_error::DnsLookupSnafu.into_error(err)), } } - Some(url::Host::Ipv4(_addr)) => Err(anyhow!("No suitable relay addr found")), - Some(url::Host::Ipv6(addr)) => Ok(SocketAddr::new(addr.into(), port)), - None => Err(anyhow!("No valid hostname in RelayUrl")), + Some(url::Host::Ipv4(addr)) => Ok(SocketAddr::new(addr.into(), port)), + Some(url::Host::Ipv6(_addr)) => Err(get_relay_addr_error::NoAddrFoundSnafu.build()), + None => Err(get_relay_addr_error::InvalidHostnameSnafu.build()), } } @@ -1230,7 +1374,7 @@ async fn run_icmp_probe( probe: Probe, relay_addr: SocketAddr, pinger: Pinger, -) -> Result { +) -> Result { match probe.proto() { ProbeProto::IcmpV4 => debug_assert!(relay_addr.is_ipv4()), ProbeProto::IcmpV6 => debug_assert!(relay_addr.is_ipv6()), @@ -1242,12 +1386,13 @@ async fn run_icmp_probe( .send(relay_addr.ip(), DATA) .await .map_err(|err| match err { - PingError::Client(err) => ProbeError::AbortSet( - anyhow!("Failed to create pinger ({err:#}), aborting probeset"), - probe.clone(), - ), + PingError::CreateClientIpv4 { .. } | PingError::CreateClientIpv6 { .. } => { + ProbeErrorWithProbe::AbortSet(probe_error::IcmpSnafu.into_error(err), probe.clone()) + } #[cfg(not(wasm_browser))] - PingError::Ping(err) => ProbeError::Error(err.into(), probe.clone()), + PingError::Ping { .. 
} => { + ProbeErrorWithProbe::Error(probe_error::IcmpSnafu.into_error(err), probe.clone()) + } })?; debug!(dst = %relay_addr, len = DATA.len(), ?latency, "ICMP ping done"); let mut report = ProbeReport::new(probe); @@ -1265,6 +1410,26 @@ async fn run_icmp_probe( Ok(report) } +#[derive(Debug, Snafu)] +#[snafu(module)] +#[non_exhaustive] +enum MeasureHttpsLatencyError { + #[snafu(transparent)] + InvalidUrl { source: url::ParseError }, + #[cfg(not(wasm_browser))] + #[snafu(transparent)] + DnsLookup { source: DnsError }, + #[cfg(not(wasm_browser))] + #[snafu(display("Invalid certificate"))] + InvalidCertificate { source: reqwest::Error }, + #[snafu(display("Creating HTTP client failed"))] + CreateReqwestClient { source: reqwest::Error }, + #[snafu(display("HTTP request failed"))] + HttpRequest { source: reqwest::Error }, + #[snafu(display("Error response from server {status}: {:?}", status.canonical_reason()))] + InvalidResponse { status: StatusCode }, +} + /// Executes an HTTPS probe. /// /// If `certs` is provided they will be added to the trusted root certificates, allowing the @@ -1274,7 +1439,7 @@ async fn measure_https_latency( #[cfg(not(wasm_browser))] dns_resolver: &DnsResolver, node: &RelayNode, certs: Option>>, -) -> Result<(Duration, IpAddr)> { +) -> Result<(Duration, IpAddr), MeasureHttpsLatencyError> { let url = node.url.join(RELAY_PROBE_PATH)?; // This should also use same connection establishment as relay client itself, which @@ -1307,21 +1472,28 @@ async fn measure_https_latency( #[cfg(not(wasm_browser))] if let Some(certs) = certs { for cert in certs { - let cert = reqwest::Certificate::from_der(&cert)?; + let cert = reqwest::Certificate::from_der(&cert) + .context(measure_https_latency_error::InvalidCertificateSnafu)?; builder = builder.add_root_certificate(cert); } } - let client = builder.build()?; + let client = builder + .build() + .context(measure_https_latency_error::CreateReqwestClientSnafu)?; let start = Instant::now(); - let response = 
client.request(reqwest::Method::GET, url).send().await?; + let response = client + .request(reqwest::Method::GET, url) + .send() + .await + .context(measure_https_latency_error::HttpRequestSnafu)?; let latency = start.elapsed(); if response.status().is_success() { // Only `None` if a different hyper HttpConnector in the request. #[cfg(not(wasm_browser))] let remote_ip = response .remote_addr() - .context("missing HttpInfo from HttpConnector")? + .expect("missing HttpInfo from HttpConnector") .ip(); #[cfg(wasm_browser)] let remote_ip = IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); @@ -1340,10 +1512,10 @@ async fn measure_https_latency( Ok((latency, remote_ip)) } else { - Err(anyhow!( - "Error response from server: '{}'", - response.status().canonical_reason().unwrap_or_default() - )) + Err(measure_https_latency_error::InvalidResponseSnafu { + status: response.status(), + } + .build()) } } @@ -1442,7 +1614,7 @@ impl Future for MaybeFuture { mod tests { use std::net::{Ipv4Addr, Ipv6Addr}; - use testresult::TestResult; + use n0_snafu::{Result, ResultExt}; use tracing_test::traced_test; use super::{super::test_utils, *}; @@ -1683,10 +1855,10 @@ mod tests { ); break; } - Err(ProbeError::Error(err, _probe)) => { + Err(ProbeErrorWithProbe::Error(err, _probe)) => { last_err = Some(err); } - Err(ProbeError::AbortSet(_err, _probe)) => { + Err(ProbeErrorWithProbe::AbortSet(_err, _probe)) => { // We don't have permission, too bad. // panic!("no ping permission: {err:#}"); break; @@ -1699,7 +1871,7 @@ mod tests { } #[tokio::test] - async fn test_measure_https_latency() -> TestResult { + async fn test_measure_https_latency() -> Result { let (server, relay) = test_utils::relay().await; let dns_resolver = dns::tests::resolver(); tracing::info!(relay_url = ?relay.url , "RELAY_URL"); @@ -1711,20 +1883,21 @@ mod tests { let relay_url_ip = relay .url .host_str() - .context("host")? 
- .parse::()?; + .unwrap() + .parse::() + .e()?; assert_eq!(ip, relay_url_ip); Ok(()) } #[tokio::test] #[traced_test] - async fn test_quic_probe() -> TestResult { + async fn test_quic_probe() -> Result { let (server, relay) = test_utils::relay().await; let relay = Arc::new(relay); let client_config = iroh_relay::client::make_dangerous_client_config(); - let ep = quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))?; - let client_addr = ep.local_addr()?; + let ep = quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)).e()?; + let client_addr = ep.local_addr().e()?; let quic_addr_disc = QuicConfig { ep: ep.clone(), client_config, @@ -1748,7 +1921,7 @@ mod tests { { Ok(probe) => probe, Err(e) => match e { - ProbeError::AbortSet(err, _) | ProbeError::Error(err, _) => { + ProbeErrorWithProbe::AbortSet(err, _) | ProbeErrorWithProbe::Error(err, _) => { return Err(err.into()); } }, diff --git a/iroh/src/net_report/reportgen/hairpin.rs b/iroh/src/net_report/reportgen/hairpin.rs index 7e62e3dd726..097a65a53ef 100644 --- a/iroh/src/net_report/reportgen/hairpin.rs +++ b/iroh/src/net_report/reportgen/hairpin.rs @@ -14,13 +14,13 @@ use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; -use anyhow::{bail, Context, Result}; use iroh_relay::protos::stun; use n0_future::{ task::{self, AbortOnDropHandle}, time::{self, Instant}, }; use netwatch::UdpSocket; +use snafu::Snafu; use tokio::sync::oneshot; use tracing::{debug, error, info_span, trace, warn, Instrument}; @@ -88,6 +88,18 @@ struct Actor { reportgen: reportgen::Addr, } +#[derive(Debug, Snafu)] +enum Error { + #[snafu(transparent)] + Io { source: std::io::Error }, + #[snafu(display("net_report actor is gone"))] + NetReportActorGone, + #[snafu(display("reportgen actor is gone"))] + ReportGenActorGone, + #[snafu(display("stun response channel dropped"))] + StunResponseGone, +} + impl Actor { async fn run(self) { match self.run_inner().await { @@ -96,8 +108,8 @@ impl Actor { } } - async fn 
run_inner(self) -> Result<()> { - let socket = UdpSocket::bind_v4(0).context("Failed to bind hairpin socket on 0.0.0.0:0")?; + async fn run_inner(self) -> Result<(), Error> { + let socket = UdpSocket::bind_v4(0)?; if let Err(err) = Self::prepare_hairpin(&socket).await { warn!("unable to send hairpin prep: {err:#}"); @@ -121,8 +133,11 @@ impl Actor { self.net_report .send(net_report::Message::InFlightStun(inflight, msg_response_tx)) .await - .context("net_report actor gone")?; - msg_response_rx.await.context("net_report actor died")?; + .map_err(|_| NetReportActorGoneSnafu.build())?; + + msg_response_rx + .await + .map_err(|_| NetReportActorGoneSnafu.build())?; if let Err(err) = socket.send_to(&stun::request(txn), dst.into()).await { warn!(%dst, "failed to send hairpin check"); @@ -132,7 +147,7 @@ impl Actor { let now = Instant::now(); let hairpinning_works = match time::timeout(HAIRPIN_CHECK_TIMEOUT, stun_rx).await { Ok(Ok(_)) => true, - Ok(Err(_)) => bail!("net_report actor dropped stun response channel"), + Ok(Err(_)) => return Err(StunResponseGoneSnafu.build()), Err(_) => false, // Elapsed }; debug!( @@ -144,14 +159,14 @@ impl Actor { self.reportgen .send(super::Message::HairpinResult(hairpinning_works)) .await - .context("Failed to send hairpin result to reportgen actor")?; + .map_err(|_| ReportGenActorGoneSnafu.build())?; trace!("reportgen notified"); Ok(()) } - async fn prepare_hairpin(socket: &UdpSocket) -> Result<()> { + async fn prepare_hairpin(socket: &UdpSocket) -> std::io::Result<()> { // At least the Apple Airport Extreme doesn't allow hairpin // sends from a private socket until it's seen traffic from // that src IP:port to something else out on the internet. 
diff --git a/iroh/src/net_report/reportgen/probes.rs b/iroh/src/net_report/reportgen/probes.rs index 3169e474373..39eb2d90695 100644 --- a/iroh/src/net_report/reportgen/probes.rs +++ b/iroh/src/net_report/reportgen/probes.rs @@ -6,12 +6,12 @@ use std::{collections::BTreeSet, fmt, sync::Arc}; -use anyhow::{ensure, Result}; use iroh_base::RelayUrl; use iroh_relay::{RelayMap, RelayNode}; use n0_future::time::Duration; #[cfg(not(wasm_browser))] use netwatch::interfaces; +use snafu::Snafu; use crate::net_report::Report; @@ -189,6 +189,10 @@ pub(super) struct ProbeSet { probes: Vec, } +#[derive(Debug, Snafu)] +#[snafu(display("Mismatching probe"))] +struct PushError; + impl ProbeSet { fn new(proto: ProbeProto) -> Self { Self { @@ -197,8 +201,10 @@ impl ProbeSet { } } - fn push(&mut self, probe: Probe) -> Result<()> { - ensure!(probe.proto() == self.proto, "mismatching probe proto"); + fn push(&mut self, probe: Probe) -> Result<(), PushError> { + if probe.proto() != self.proto { + return Err(PushError); + } self.probes.push(probe); Ok(()) } diff --git a/iroh/src/protocol.rs b/iroh/src/protocol.rs index 56abe7e29c3..9702ce5e1b2 100644 --- a/iroh/src/protocol.rs +++ b/iroh/src/protocol.rs @@ -3,10 +3,9 @@ //! ## Example //! //! ```no_run -//! # use anyhow::Result; -//! # use iroh::{endpoint::Connection, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr}; +//! # use iroh::{endpoint::{Connection, BindError}, protocol::{ProtocolHandler, Router, ProtocolError}, Endpoint, NodeAddr}; //! # -//! # async fn test_compile() -> Result<()> { +//! # async fn test_compile() -> Result<(), BindError> { //! let endpoint = Endpoint::builder().discovery_n0().bind().await?; //! //! let router = Router::builder(endpoint) @@ -20,7 +19,7 @@ //! struct Echo; //! //! impl ProtocolHandler for Echo { -//! async fn accept(&self, connection: Connection) -> Result<()> { +//! async fn accept(&self, connection: Connection) -> Result<(), ProtocolError> { //! 
let (mut send, mut recv) = connection.accept_bi().await?; //! //! // Echo any bytes received back directly. @@ -35,18 +34,18 @@ //! ``` use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc}; -use anyhow::Result; use iroh_base::NodeId; use n0_future::{ join_all, task::{self, AbortOnDropHandle, JoinSet}, }; +use snafu::{Backtrace, Snafu}; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; use tracing::{error, info_span, trace, warn, Instrument}; use crate::{ - endpoint::{Connecting, Connection}, + endpoint::{Connecting, Connection, RemoteNodeIdError}, Endpoint, }; @@ -64,11 +63,11 @@ use crate::{ /// /// ```no_run /// # use std::sync::Arc; -/// # use anyhow::Result; /// # use futures_lite::future::Boxed as BoxedFuture; +/// # use n0_snafu::ResultExt; /// # use iroh::{endpoint::Connecting, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr}; /// # -/// # async fn test_compile() -> Result<()> { +/// # async fn test_compile() -> n0_snafu::Result<()> { /// let endpoint = Endpoint::builder().discovery_n0().bind().await?; /// /// let router = Router::builder(endpoint) @@ -76,8 +75,8 @@ use crate::{ /// .spawn(); /// /// // wait until the user wants to -/// tokio::signal::ctrl_c().await?; -/// router.shutdown().await?; +/// tokio::signal::ctrl_c().await.context("ctrl+c")?; +/// router.shutdown().await.context("shutdown")?; /// # Ok(()) /// # } /// ``` @@ -96,6 +95,49 @@ pub struct RouterBuilder { protocols: ProtocolMap, } +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ProtocolError { + #[snafu(transparent)] + Connect { + source: crate::endpoint::ConnectionError, + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, + #[snafu(transparent)] + RemoteNodeId { source: RemoteNodeIdError }, + #[snafu(display("Not allowed."))] + NotAllowed {}, + + #[snafu(transparent)] + User { + source: Box, + }, +} + +impl ProtocolError { + /// Creates a new user error from an arbitrary error type. 
+ pub fn from_err(value: T) -> Self { + Self::User { + source: Box::new(value), + } + } +} + +impl From for ProtocolError { + fn from(err: std::io::Error) -> Self { + Self::from_err(err) + } +} + +impl From for ProtocolError { + fn from(err: quinn::ClosedStream) -> Self { + Self::from_err(err) + } +} + /// Handler for incoming connections. /// /// A router accepts connections for arbitrary ALPN protocols. @@ -116,7 +158,7 @@ pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static { fn on_connecting( &self, connecting: Connecting, - ) -> impl Future> + Send { + ) -> impl Future> + Send { async move { let conn = connecting.await?; Ok(conn) @@ -132,7 +174,10 @@ pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static { /// When [`Router::shutdown`] is called, no further connections will be accepted, and /// the futures returned by [`Self::accept`] will be aborted after the future returned /// from [`ProtocolHandler::shutdown`] completes. - fn accept(&self, connection: Connection) -> impl Future> + Send; + fn accept( + &self, + connection: Connection, + ) -> impl Future> + Send; /// Called when the router shuts down. 
/// @@ -146,11 +191,11 @@ pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static { } impl ProtocolHandler for Arc { - async fn on_connecting(&self, conn: Connecting) -> Result { + async fn on_connecting(&self, conn: Connecting) -> Result { self.as_ref().on_connecting(conn).await } - async fn accept(&self, conn: Connection) -> Result<()> { + async fn accept(&self, conn: Connection) -> Result<(), ProtocolError> { self.as_ref().accept(conn).await } @@ -160,11 +205,11 @@ impl ProtocolHandler for Arc { } impl ProtocolHandler for Box { - async fn on_connecting(&self, conn: Connecting) -> Result { + async fn on_connecting(&self, conn: Connecting) -> Result { self.as_ref().on_connecting(conn).await } - async fn accept(&self, conn: Connection) -> Result<()> { + async fn accept(&self, conn: Connection) -> Result<(), ProtocolError> { self.as_ref().accept(conn).await } @@ -182,7 +227,7 @@ pub(crate) trait DynProtocolHandler: Send + Sync + std::fmt::Debug + 'static { fn on_connecting( &self, connecting: Connecting, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + '_>> { Box::pin(async move { let conn = connecting.await?; Ok(conn) @@ -193,7 +238,7 @@ pub(crate) trait DynProtocolHandler: Send + Sync + std::fmt::Debug + 'static { fn accept( &self, connection: Connection, - ) -> Pin> + Send + '_>>; + ) -> Pin> + Send + '_>>; /// See [`ProtocolHandler::shutdown`]. fn shutdown(&self) -> Pin + Send + '_>> { @@ -205,14 +250,14 @@ impl DynProtocolHandler for P { fn accept( &self, connection: Connection, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + '_>> { Box::pin(::accept(self, connection)) } fn on_connecting( &self, connecting: Connecting, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + '_>> { Box::pin(::on_connecting(self, connecting)) } @@ -276,7 +321,7 @@ impl Router { /// /// If some [`ProtocolHandler`] panicked in the accept loop, this will propagate /// that panic into the result here. 
- pub async fn shutdown(&self) -> Result<()> { + pub async fn shutdown(&self) -> Result<(), n0_future::task::JoinError> { if self.is_shutdown() { return Ok(()); } @@ -471,16 +516,19 @@ impl AccessLimit

{ } impl ProtocolHandler for AccessLimit

{ - fn on_connecting(&self, conn: Connecting) -> impl Future> + Send { + fn on_connecting( + &self, + conn: Connecting, + ) -> impl Future> + Send { self.proto.on_connecting(conn) } - async fn accept(&self, conn: Connection) -> Result<()> { + async fn accept(&self, conn: Connection) -> Result<(), ProtocolError> { let remote = conn.remote_node_id()?; let is_allowed = (self.limiter)(remote); if !is_allowed { conn.close(0u32.into(), b"not allowed"); - anyhow::bail!("not allowed"); + return Err(NotAllowedSnafu.build()); } self.proto.accept(conn).await?; Ok(()) @@ -495,21 +543,21 @@ impl ProtocolHandler for AccessLimit

{ mod tests { use std::{sync::Mutex, time::Duration}; + use n0_snafu::{Result, ResultExt}; use quinn::ApplicationClose; - use testresult::TestResult; use super::*; use crate::{endpoint::ConnectionError, watcher::Watcher, RelayMode}; #[tokio::test] - async fn test_shutdown() -> Result<()> { + async fn test_shutdown() -> Result { let endpoint = Endpoint::builder().bind().await?; let router = Router::builder(endpoint.clone()).spawn(); assert!(!router.is_shutdown()); assert!(!endpoint.is_closed()); - router.shutdown().await?; + router.shutdown().await.e()?; assert!(router.is_shutdown()); assert!(endpoint.is_closed()); @@ -524,7 +572,7 @@ mod tests { const ECHO_ALPN: &[u8] = b"/iroh/echo/1"; impl ProtocolHandler for Echo { - async fn accept(&self, connection: Connection) -> Result<()> { + async fn accept(&self, connection: Connection) -> Result<(), ProtocolError> { println!("accepting echo"); let (mut send, mut recv) = connection.accept_bi().await?; @@ -538,7 +586,7 @@ mod tests { } } #[tokio::test] - async fn test_limiter() -> Result<()> { + async fn test_limiter() -> Result { let e1 = Endpoint::builder().bind().await?; // deny all access let proto = AccessLimit::new(Echo, |_node_id| false); @@ -551,18 +599,18 @@ mod tests { println!("connecting"); let conn = e2.connect(addr1, ECHO_ALPN).await?; - let (_send, mut recv) = conn.open_bi().await?; + let (_send, mut recv) = conn.open_bi().await.e()?; let response = recv.read_to_end(1000).await.unwrap_err(); assert!(format!("{:#?}", response).contains("not allowed")); - r1.shutdown().await?; + r1.shutdown().await.e()?; e2.close().await; Ok(()) } #[tokio::test] - async fn test_graceful_shutdown() -> TestResult { + async fn test_graceful_shutdown() -> Result { #[derive(Debug, Clone, Default)] struct TestProtocol { connections: Arc>>, @@ -571,7 +619,7 @@ mod tests { const TEST_ALPN: &[u8] = b"/iroh/test/1"; impl ProtocolHandler for TestProtocol { - async fn accept(&self, connection: Connection) -> Result<()> { + async fn 
accept(&self, connection: Connection) -> Result<(), ProtocolError> { self.connections.lock().expect("poisoned").push(connection); Ok(()) } @@ -600,7 +648,7 @@ mod tests { .await?; let conn = endpoint2.connect(addr, TEST_ALPN).await?; - router.shutdown().await?; + router.shutdown().await.e()?; let reason = conn.closed().await; assert_eq!( diff --git a/iroh/src/test_utils.rs b/iroh/src/test_utils.rs index 455675b2e0e..368c66b1394 100644 --- a/iroh/src/test_utils.rs +++ b/iroh/src/test_utils.rs @@ -1,13 +1,12 @@ //! Internal utilities to support testing. use std::net::Ipv4Addr; -use anyhow::Result; pub use dns_and_pkarr_servers::DnsPkarrServer; use iroh_base::RelayUrl; use iroh_relay::{ server::{ - AccessConfig, CertConfig, QuicConfig, RelayConfig, Server, ServerConfig, StunConfig, - TlsConfig, + AccessConfig, CertConfig, QuicConfig, RelayConfig, Server, ServerConfig, SpawnError, + StunConfig, TlsConfig, }, RelayMap, RelayNode, RelayQuicConfig, }; @@ -29,7 +28,7 @@ pub struct CleanupDropGuard(pub(crate) oneshot::Sender<()>); /// /// The returned `Url` is the url of the relay server in the returned [`RelayMap`]. /// When dropped, the returned [`Server`] does will stop running. -pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, Server)> { +pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, Server), SpawnError> { run_relay_server_with( Some(StunConfig { bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), @@ -43,7 +42,7 @@ pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, Server)> { /// /// The returned `Url` is the url of the relay server in the returned [`RelayMap`]. /// When dropped, the returned [`Server`] does will stop running. 
-pub async fn run_relay_server_with_stun() -> Result<(RelayMap, RelayUrl, Server)> { +pub async fn run_relay_server_with_stun() -> Result<(RelayMap, RelayUrl, Server), SpawnError> { run_relay_server_with( Some(StunConfig { bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), @@ -65,7 +64,7 @@ pub async fn run_relay_server_with_stun() -> Result<(RelayMap, RelayUrl, Server) pub async fn run_relay_server_with( stun: Option, quic: bool, -) -> Result<(RelayMap, RelayUrl, Server)> { +) -> Result<(RelayMap, RelayUrl, Server), SpawnError> { let (certs, server_config) = iroh_relay::server::testing::self_signed_tls_certs_and_config(); let tls = TlsConfig { @@ -92,11 +91,12 @@ pub async fn run_relay_server_with( }), quic, stun, - #[cfg(feature = "metrics")] - metrics_addr: None, + ..Default::default() }; let server = Server::spawn(config).await?; - let url: RelayUrl = format!("https://{}", server.https_addr().expect("configured")).parse()?; + let url: RelayUrl = format!("https://{}", server.https_addr().expect("configured")) + .parse() + .expect("invalid relay url"); let quic = server .quic_addr() @@ -114,7 +114,6 @@ pub async fn run_relay_server_with( pub(crate) mod dns_and_pkarr_servers { use std::{net::SocketAddr, time::Duration}; - use anyhow::Result; use iroh_base::{NodeId, SecretKey}; use url::Url; @@ -146,12 +145,12 @@ pub(crate) mod dns_and_pkarr_servers { impl DnsPkarrServer { /// Run DNS and Pkarr servers on localhost. - pub async fn run() -> anyhow::Result { + pub async fn run() -> std::io::Result { Self::run_with_origin("dns.iroh.test".to_string()).await } /// Run DNS and Pkarr servers on localhost with the specified `node_origin` domain. 
- pub async fn run_with_origin(node_origin: String) -> anyhow::Result { + pub async fn run_with_origin(node_origin: String) -> std::io::Result { let state = State::new(node_origin.clone()); let (nameserver, dns_drop_guard) = run_dns_server(state.clone()).await?; let (pkarr_url, pkarr_drop_guard) = run_pkarr_relay(state.clone()).await?; @@ -184,7 +183,7 @@ pub(crate) mod dns_and_pkarr_servers { /// Wait until a Pkarr announce for a node is published to the server. /// /// If `timeout` elapses an error is returned. - pub async fn on_node(&self, node_id: &NodeId, timeout: Duration) -> Result<()> { + pub async fn on_node(&self, node_id: &NodeId, timeout: Duration) -> std::io::Result<()> { self.state.on_node(node_id, timeout).await } } @@ -196,7 +195,6 @@ pub(crate) mod dns_server { net::{Ipv4Addr, SocketAddr}, }; - use anyhow::{ensure, Result}; use hickory_resolver::proto::{ op::{header::MessageType, Message}, serialize::binary::BinDecodable, @@ -213,18 +211,19 @@ pub(crate) mod dns_server { &self, query: &Message, reply: &mut Message, - ) -> impl Future> + Send; + ) -> impl Future> + Send; } - pub type QueryHandlerFunction = - Box BoxFuture> + Send + Sync + 'static>; + pub type QueryHandlerFunction = Box< + dyn Fn(&Message, &mut Message) -> BoxFuture> + Send + Sync + 'static, + >; impl QueryHandler for QueryHandlerFunction { fn resolve( &self, query: &Message, reply: &mut Message, - ) -> impl Future> + Send { + ) -> impl Future> + Send { (self)(query, reply) } } @@ -234,7 +233,7 @@ pub(crate) mod dns_server { /// Must pass a [`QueryHandler`] that answers queries. 
pub async fn run_dns_server( resolver: impl QueryHandler, - ) -> Result<(SocketAddr, CleanupDropGuard)> { + ) -> std::io::Result<(SocketAddr, CleanupDropGuard)> { let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); let socket = UdpSocket::bind(bind_addr).await?; let bound_addr = socket.local_addr()?; @@ -261,7 +260,7 @@ pub(crate) mod dns_server { } impl TestDnsServer { - async fn run(self) -> Result<()> { + async fn run(self) -> std::io::Result<()> { let mut buf = [0; 1450]; loop { let res = self.socket.recv_from(&mut buf).await; @@ -272,7 +271,7 @@ pub(crate) mod dns_server { } } - async fn handle_datagram(&self, from: SocketAddr, buf: &[u8]) -> Result<()> { + async fn handle_datagram(&self, from: SocketAddr, buf: &[u8]) -> std::io::Result<()> { let packet = Message::from_bytes(buf)?; debug!(queries = ?packet.queries(), %from, "received query"); let mut reply = packet.clone(); @@ -281,7 +280,7 @@ pub(crate) mod dns_server { debug!(?reply, %from, "send reply"); let buf = reply.to_vec()?; let len = self.socket.send_to(&buf, from).await?; - ensure!(len == buf.len(), "failed to send complete packet"); + assert_eq!(len, buf.len(), "failed to send complete packet"); Ok(()) } } @@ -293,7 +292,6 @@ pub(crate) mod pkarr_relay { net::{Ipv4Addr, SocketAddr}, }; - use anyhow::Result; use axum::{ extract::{Path, State}, response::IntoResponse, @@ -308,7 +306,7 @@ pub(crate) mod pkarr_relay { use super::CleanupDropGuard; use crate::test_utils::pkarr_dns_state::State as AppState; - pub async fn run_pkarr_relay(state: AppState) -> Result<(Url, CleanupDropGuard)> { + pub async fn run_pkarr_relay(state: AppState) -> std::io::Result<(Url, CleanupDropGuard)> { let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); let app = Router::new() .route("/pkarr/{key}", put(pkarr_put)) @@ -341,15 +339,16 @@ pub(crate) mod pkarr_relay { Path(key): Path, body: Bytes, ) -> Result { - let key = pkarr::PublicKey::try_from(key.as_str())?; - let signed_packet = 
pkarr::SignedPacket::from_relay_payload(&key, &body)?; + let key = pkarr::PublicKey::try_from(key.as_str()).map_err(std::io::Error::other)?; + let signed_packet = + pkarr::SignedPacket::from_relay_payload(&key, &body).map_err(std::io::Error::other)?; let _updated = state.upsert(signed_packet)?; Ok(http::StatusCode::NO_CONTENT) } #[derive(Debug)] - struct AppError(anyhow::Error); - impl> From for AppError { + struct AppError(std::io::Error); + impl> From for AppError { fn from(value: T) -> Self { Self(value.into()) } @@ -370,7 +369,6 @@ pub(crate) mod pkarr_dns_state { time::Duration, }; - use anyhow::{bail, Result}; use iroh_base::NodeId; use iroh_relay::node_info::{NodeIdExt, NodeInfo, IROH_TXT_NAME}; use pkarr::SignedPacket; @@ -397,20 +395,21 @@ pub(crate) mod pkarr_dns_state { self.notify.notified() } - pub async fn on_node(&self, node: &NodeId, timeout: Duration) -> Result<()> { + pub async fn on_node(&self, node: &NodeId, timeout: Duration) -> std::io::Result<()> { let timeout = tokio::time::sleep(timeout); tokio::pin!(timeout); while self.get(node, |p| p.is_none()) { tokio::select! 
{ - _ = &mut timeout => bail!("timeout"), + _ = &mut timeout => return Err(std::io::Error::other("timeout")), _ = self.on_update() => {} } } Ok(()) } - pub fn upsert(&self, signed_packet: SignedPacket) -> anyhow::Result { - let node_id = NodeId::from_bytes(&signed_packet.public_key().to_bytes())?; + pub fn upsert(&self, signed_packet: SignedPacket) -> std::io::Result { + let node_id = NodeId::from_bytes(&signed_packet.public_key().to_bytes()) + .map_err(std::io::Error::other)?; let mut map = self.packets.lock().expect("poisoned"); let updated = match map.entry(node_id) { hash_map::Entry::Vacant(e) => { @@ -447,7 +446,7 @@ pub(crate) mod pkarr_dns_state { query: &hickory_resolver::proto::op::Message, reply: &mut hickory_resolver::proto::op::Message, ttl: u32, - ) -> Result<()> { + ) -> std::io::Result<()> { for query in query.queries() { let domain_name = query.name().to_string(); let Some(node_id) = node_id_from_domain_name(&domain_name) else { @@ -456,12 +455,13 @@ pub(crate) mod pkarr_dns_state { self.get(&node_id, |packet| { if let Some(packet) = packet { - let node_info = NodeInfo::from_pkarr_signed_packet(packet)?; - for record in node_info_to_hickory_records(&node_info, &self.origin, ttl)? 
{ + let node_info = NodeInfo::from_pkarr_signed_packet(packet) + .map_err(std::io::Error::other)?; + for record in node_info_to_hickory_records(&node_info, &self.origin, ttl) { reply.add_answer(record); } } - anyhow::Ok(()) + Ok::<_, std::io::Error>(()) })?; } Ok(()) @@ -473,7 +473,7 @@ pub(crate) mod pkarr_dns_state { &self, query: &hickory_resolver::proto::op::Message, reply: &mut hickory_resolver::proto::op::Message, - ) -> impl Future> + Send { + ) -> impl Future> + Send { const TTL: u32 = 30; let res = self.resolve_dns(query, reply, TTL); std::future::ready(res) @@ -503,10 +503,10 @@ pub(crate) mod pkarr_dns_state { node_info: &NodeInfo, origin: &str, ttl: u32, - ) -> Result + 'static> { + ) -> impl Iterator + 'static { let txt_strings = node_info.to_txt_strings(); - let records = to_hickory_records(txt_strings, node_info.node_id, origin, ttl)?; - Ok(records.collect::>().into_iter()) + let records = to_hickory_records(txt_strings, node_info.node_id, origin, ttl); + records.collect::>().into_iter() } /// Converts to a list of [`hickory_resolver::proto::rr::Record`] resource records. 
@@ -515,25 +515,24 @@ pub(crate) mod pkarr_dns_state { node_id: NodeId, origin: &str, ttl: u32, - ) -> Result + '_> { + ) -> impl Iterator + '_ { use hickory_resolver::proto::rr; let name = format!("{}.{}.{}", IROH_TXT_NAME, node_id.to_z32(), origin); - let name = rr::Name::from_utf8(name)?; - let records = txt_strings.into_iter().map(move |s| { + let name = rr::Name::from_utf8(name).expect("invalid name"); + txt_strings.into_iter().map(move |s| { let txt = rr::rdata::TXT::new(vec![s]); let rdata = rr::RData::TXT(txt); rr::Record::from_rdata(name.clone(), ttl, rdata) - }); - Ok(records) + }) } #[cfg(test)] mod tests { use iroh_base::NodeId; - use testresult::TestResult; + use n0_snafu::Result; #[test] - fn test_node_id_from_domain_name() -> TestResult { + fn test_node_id_from_domain_name() -> Result { let name = "_iroh.dgjpkxyn3zyrk3zfads5duwdgbqpkwbjxfj4yt7rezidr3fijccy.dns.iroh.link."; let node_id = super::node_id_from_domain_name(name); let expected: NodeId = diff --git a/iroh/src/tls/certificate.rs b/iroh/src/tls/certificate.rs index 7970ebf9ad3..96a8cf689ee 100644 --- a/iroh/src/tls/certificate.rs +++ b/iroh/src/tls/certificate.rs @@ -7,6 +7,7 @@ use der::{asn1::OctetStringRef, Decode, Encode, Sequence}; use iroh_base::{PublicKey, SecretKey, Signature}; +use snafu::Snafu; use x509_parser::prelude::*; /// The libp2p Public Key Extension is a X.509 extension @@ -104,19 +105,25 @@ pub(crate) struct P2pExtension { } /// An error that occurs during certificate generation. -#[derive(Debug, thiserror::Error)] -#[error(transparent)] -pub(crate) struct GenError(#[from] rcgen::Error); +#[derive(Debug, Snafu)] +#[snafu(transparent)] +pub struct GenError { + source: rcgen::Error, +} /// An error that occurs during certificate parsing. 
-#[derive(Debug, thiserror::Error)] -#[error(transparent)] -pub(crate) struct ParseError(#[from] pub(crate) webpki::Error); +#[derive(Debug, Snafu)] +#[snafu(transparent)] +pub(crate) struct ParseError { + pub(crate) source: webpki::Error, +} /// An error that occurs during signature verification. -#[derive(Debug, thiserror::Error)] -#[error(transparent)] -pub(crate) struct VerificationError(#[from] pub(crate) webpki::Error); +#[derive(Debug, Snafu)] +#[snafu(transparent)] +pub(crate) struct VerificationError { + pub(crate) source: webpki::Error, +} /// Internal function that only parses but does not verify the certificate. /// diff --git a/iroh/src/tls/resolver.rs b/iroh/src/tls/resolver.rs index dc8cec9170d..3783b4ec152 100644 --- a/iroh/src/tls/resolver.rs +++ b/iroh/src/tls/resolver.rs @@ -2,6 +2,8 @@ use std::sync::Arc; use ed25519_dalek::pkcs8::{spki::der::pem::LineEnding, EncodePrivateKey}; use iroh_base::SecretKey; +use nested_enum_utils::common_fields; +use snafu::Snafu; use webpki_types::{pem::PemObject, CertificateDer, PrivatePkcs8KeyDer}; use super::certificate; @@ -14,14 +16,20 @@ pub(super) struct AlwaysResolvesCert { } /// Error for generating TLS configs. -#[derive(Debug, thiserror::Error)] +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[derive(Debug, Snafu)] +#[non_exhaustive] pub(super) enum CreateConfigError { /// Error generating the certificate. 
- #[error("Error generating the certificate")] - CertError(#[from] certificate::GenError), + #[snafu(display("Error generating the certificate"), context(false))] + CertError { source: certificate::GenError }, /// Rustls configuration error - #[error("rustls error")] - Rustls(#[from] rustls::Error), + #[snafu(display("rustls error"), context(false))] + Rustls { source: rustls::Error }, } impl AlwaysResolvesCert { diff --git a/iroh/src/tls/verifier.rs b/iroh/src/tls/verifier.rs index 81be334e2ad..ceec1b00a45 100644 --- a/iroh/src/tls/verifier.rs +++ b/iroh/src/tls/verifier.rs @@ -321,9 +321,9 @@ fn verify_tls13_signature( } impl From for rustls::Error { - fn from(certificate::ParseError(e): certificate::ParseError) -> Self { + fn from(certificate::ParseError { source }: certificate::ParseError) -> Self { use webpki::Error::*; - match e { + match source { BadDer => rustls::Error::InvalidCertificate(CertificateError::BadEncoding), e => { rustls::Error::InvalidCertificate(CertificateError::Other(OtherError(Arc::new(e)))) @@ -332,9 +332,9 @@ impl From for rustls::Error { } } impl From for rustls::Error { - fn from(certificate::VerificationError(e): certificate::VerificationError) -> Self { + fn from(certificate::VerificationError { source }: certificate::VerificationError) -> Self { use webpki::Error::*; - match e { + match source { InvalidSignatureForPublicKey => { rustls::Error::InvalidCertificate(CertificateError::BadSignature) } diff --git a/iroh/src/watcher.rs b/iroh/src/watcher.rs index 309e60bb15e..bc8a0ebf74a 100644 --- a/iroh/src/watcher.rs +++ b/iroh/src/watcher.rs @@ -25,6 +25,7 @@ use std::{ #[cfg(iroh_loom)] use loom::sync; +use snafu::Snafu; use sync::{Mutex, RwLock}; /// A wrapper around a value that notifies [`Watcher`]s when the value is modified. 
@@ -173,7 +174,7 @@ pub trait Watcher: Clone { initial: match self.get() { Ok(Some(value)) => Some(Ok(value)), Ok(None) => None, - Err(Disconnected) => Some(Err(Disconnected)), + Err(val) => Some(Err(val)), }, watcher: self, } @@ -261,7 +262,7 @@ impl Watcher for Direct { type Value = T; fn get(&self) -> Result { - let shared = self.shared.upgrade().ok_or(Disconnected)?; + let shared = self.shared.upgrade().ok_or(DisconnectedSnafu.build())?; Ok(shared.get()) } @@ -270,7 +271,7 @@ impl Watcher for Direct { cx: &mut task::Context<'_>, ) -> Poll> { let Some(shared) = self.shared.upgrade() else { - return Poll::Ready(Err(Disconnected)); + return Poll::Ready(Err(DisconnectedSnafu.build())); }; match shared.poll_updated(cx, self.epoch) { Poll::Pending => Poll::Pending, @@ -415,7 +416,7 @@ where } match self.as_mut().watcher.poll_updated(cx) { Poll::Ready(Ok(value)) => Poll::Ready(Some(value)), - Poll::Ready(Err(Disconnected)) => Poll::Ready(None), + Poll::Ready(Err(_)) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } @@ -423,9 +424,13 @@ where /// The error for when a [`Watcher`] is disconnected from its underlying /// [`Watchable`] value, because of that watchable having been dropped. 
-#[derive(thiserror::Error, Debug)] -#[error("Watcher lost connection to underlying Watchable, it was dropped")] -pub struct Disconnected; +#[derive(Debug, Snafu)] +#[snafu(display("Watch lost connection to underlying Watchable, it was dropped"))] +pub struct Disconnected { + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +} // Private: diff --git a/iroh/tests/integration.rs b/iroh/tests/integration.rs index 0ccac01cb2f..72aeadfd0cc 100644 --- a/iroh/tests/integration.rs +++ b/iroh/tests/integration.rs @@ -18,7 +18,7 @@ use n0_future::{ time::{self, Duration}, StreamExt, }; -use testresult::TestResult; +use n0_snafu::{Result, ResultExt}; #[cfg(not(wasm_browser))] use tokio::test; use tracing::{info_span, Instrument}; @@ -33,7 +33,7 @@ use wasm_bindgen_test::wasm_bindgen_test as test; const ECHO_ALPN: &[u8] = b"echo"; #[test] -async fn simple_node_id_based_connection_transfer() -> TestResult { +async fn simple_node_id_based_connection_transfer() -> Result { setup_logging(); let client = Endpoint::builder().discovery_n0().bind().await?; @@ -48,24 +48,24 @@ async fn simple_node_id_based_connection_transfer() -> TestResult { let server = server.clone(); async move { while let Some(incoming) = server.accept().await { - let conn = incoming.await?; + let conn = incoming.await.e()?; let node_id = conn.remote_node_id()?; tracing::info!(node_id = %node_id.fmt_short(), "Accepted connection"); - let (mut send, mut recv) = conn.accept_bi().await?; + let (mut send, mut recv) = conn.accept_bi().await.e()?; let mut bytes_sent = 0; - while let Some(chunk) = recv.read_chunk(10_000, true).await? { + while let Some(chunk) = recv.read_chunk(10_000, true).await.e()? 
{ bytes_sent += chunk.bytes.len(); - send.write_chunk(chunk.bytes).await?; + send.write_chunk(chunk.bytes).await.e()?; } - send.finish()?; + send.finish().e()?; tracing::info!("Copied over {bytes_sent} byte(s)"); let code = conn.closed().await; tracing::info!("Closed with code: {code:?}"); } - TestResult::Ok(()) + Ok::<_, n0_snafu::Error>(()) } .instrument(info_span!("server")) }); @@ -92,18 +92,19 @@ async fn simple_node_id_based_connection_transfer() -> TestResult { } } }) - .await?; + .await + .e()?; tracing::info!(to = %server.node_id().fmt_short(), "Opening a connection"); let conn = client.connect(server.node_id(), ECHO_ALPN).await?; tracing::info!("Connection opened"); - let (mut send, mut recv) = conn.open_bi().await?; - send.write_all(b"Hello, World!").await?; - send.finish()?; + let (mut send, mut recv) = conn.open_bi().await.e()?; + send.write_all(b"Hello, World!").await.e()?; + send.finish().e()?; tracing::info!("Sent request"); - let response = recv.read_to_end(10_000).await?; + let response = recv.read_to_end(10_000).await.e()?; tracing::info!(len = response.len(), "Received response"); assert_eq!(&response, b"Hello, World!");