diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index a82569bca..017e6e37f 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -7,6 +7,16 @@ license = "Apache-2.0" [dependencies] url = "2.5.0" -tokio = { version = "1.37.0", features = ["sync"] } +tokio = { version = "1.37.0", features = ["sync", "rt", "net", "time", "macros"] } tonic = { version = "0.13.0", path = "../tonic", default-features = false, features = ["codegen"] } futures-core = "0.3.31" +once_cell = "1.19.0" +hickory-resolver = { version = "0.25.1", optional = true } +rand = "0.8.5" + +[dev-dependencies] +hickory-server = "0.25.2" + +[features] +default = ["hickory_dns"] +hickory_dns = ["dep:hickory-resolver"] diff --git a/grpc/src/client/mod.rs b/grpc/src/client/mod.rs index c9e2365ce..4108411b3 100644 --- a/grpc/src/client/mod.rs +++ b/grpc/src/client/mod.rs @@ -19,9 +19,9 @@ use std::fmt::Display; pub mod channel; -pub mod service; pub(crate) mod load_balancing; pub(crate) mod name_resolution; +pub mod service; pub mod service_config; /// A representation of the current state of a gRPC channel, also used for the diff --git a/grpc/src/client/name_resolution/backoff.rs b/grpc/src/client/name_resolution/backoff.rs new file mode 100644 index 000000000..894155f80 --- /dev/null +++ b/grpc/src/client/name_resolution/backoff.rs @@ -0,0 +1,235 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use rand::Rng; +use std::time::Duration; + +#[derive(Clone)] +pub struct BackoffConfig { + /// The amount of time to backoff after the first failure. + pub base_delay: Duration, + + /// The factor with which to multiply backoffs after a + /// failed retry. Should ideally be greater than 1. + pub multiplier: f64, + + /// The factor with which backoffs are randomized. + pub jitter: f64, + + /// The upper bound of backoff delay. + pub max_delay: Duration, +} + +pub struct ExponentialBackoff { + config: BackoffConfig, + + /// The delay for the next retry, without the random jitter. Store as f64 + /// to avoid rounding errors. + next_delay_secs: f64, +} + +/// This is a backoff configuration with the default values specified +/// at https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md. +/// +/// This should be useful for callers who want to configure backoff with +/// non-default values only for a subset of the options. +pub const DEFAULT_EXPONENTIAL_CONFIG: BackoffConfig = BackoffConfig { + base_delay: Duration::from_secs(1), + multiplier: 1.6, + jitter: 0.2, + max_delay: Duration::from_secs(120), +}; + +impl BackoffConfig { + fn validate(&self) -> Result<(), &'static str> { + // Valid that params are in valid ranges. + // 0 <= base_dealy <= max_delay + if self.base_delay > self.max_delay { + Err("base_delay must be greater than max_delay")?; + } + // 1 <= multiplier + if self.multiplier < 1.0 { + Err("multiplier must be greater than 1.0")?; + } + // 0 <= jitter <= 1 + if self.jitter < 0.0 { + Err("jitter must be greater than or equal to 0")?; + } + if self.jitter > 1.0 { + Err("jitter must be less than or equal to 1")? + } + Ok(()) + } +} + +impl ExponentialBackoff { + pub fn new(config: BackoffConfig) -> Result { + config.validate()?; + let next_delay_secs = config.base_delay.as_secs_f64(); + Ok(ExponentialBackoff { + config, + next_delay_secs: next_delay_secs, + }) + } + + pub fn reset(&mut self) { + self.next_delay_secs = self.config.base_delay.as_secs_f64(); + } + + pub fn backoff_duration(&mut self) -> Duration { + let next_delay = self.next_delay_secs; + let cur_delay = + next_delay * (1.0 + self.config.jitter * rand::thread_rng().gen_range(-1.0..1.0)); + self.next_delay_secs = self + .config + .max_delay + .as_secs_f64() + .min(next_delay * self.config.multiplier); + Duration::from_secs_f64(cur_delay) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use crate::client::name_resolution::backoff::{ + BackoffConfig, ExponentialBackoff, DEFAULT_EXPONENTIAL_CONFIG, + }; + + // Epsilon for floating point comparisons if needed, though Duration + // comparisons are often better. + const EPSILON: f64 = 1e-9; + + #[test] + fn default_config_is_valid() { + let result = ExponentialBackoff::new(DEFAULT_EXPONENTIAL_CONFIG.clone()); + assert_eq!(result.is_ok(), true); + } + + #[test] + fn base_less_than_max() { + let config = BackoffConfig { + base_delay: Duration::from_secs(10), + multiplier: 123.0, + jitter: 0.0, + max_delay: Duration::from_secs(100), + }; + let mut backoff = ExponentialBackoff::new(config).unwrap(); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + } + + #[test] + fn base_more_than_max() { + let config = BackoffConfig { + multiplier: 123.0, + jitter: 0.0, + base_delay: Duration::from_secs(100), + max_delay: Duration::from_secs(10), + }; + let result = ExponentialBackoff::new(config); + assert_eq!(result.is_err(), true); + } + + #[test] + fn negative_multiplier() { + let config = BackoffConfig { + multiplier: -123.0, + jitter: 0.0, + base_delay: Duration::from_secs(10), + max_delay: Duration::from_secs(100), + }; + let result = ExponentialBackoff::new(config); + assert_eq!(result.is_err(), true); + } + + #[test] + fn negative_jitter() { + let config = BackoffConfig { + multiplier: 1.0, + jitter: -10.0, + base_delay: Duration::from_secs(10), + max_delay: Duration::from_secs(100), + }; + let result = ExponentialBackoff::new(config); + assert_eq!(result.is_err(), true); + } + + #[test] + fn jitter_greater_than_one() { + let config = BackoffConfig { + multiplier: 1.0, + jitter: 2.0, + base_delay: Duration::from_secs(10), + max_delay: Duration::from_secs(100), + }; + let result = ExponentialBackoff::new(config); + assert_eq!(result.is_err(), true); + } + + #[test] + fn backoff_reset_no_jitter() { + let config = BackoffConfig { + multiplier: 2.0, + jitter: 0.0, + base_delay: Duration::from_secs(1), + max_delay: Duration::from_secs(15), + }; + let mut backoff = ExponentialBackoff::new(config.clone()).unwrap(); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(1)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(2)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(4)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(8)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(15)); + // Duration is capped to max_delay. + assert_eq!(backoff.backoff_duration(), Duration::from_secs(15)); + + // reset and repeat. + backoff.reset(); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(1)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(2)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(4)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(8)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(15)); + // Duration is capped to max_delay. + assert_eq!(backoff.backoff_duration(), Duration::from_secs(15)); + } + + #[test] + fn backoff_with_jitter() { + let config = BackoffConfig { + multiplier: 2.0, + jitter: 0.2, + base_delay: Duration::from_secs(1), + max_delay: Duration::from_secs(15), + }; + let mut backoff = ExponentialBackoff::new(config.clone()).unwrap(); + // 0.8 <= duration <= 1.2. + let duration = backoff.backoff_duration(); + assert_eq!(duration.gt(&Duration::from_secs_f64(0.8 - EPSILON)), true); + assert_eq!(duration.lt(&Duration::from_secs_f64(1.2 + EPSILON)), true); + // 1.6 <= duration <= 2.4. + let duration = backoff.backoff_duration(); + assert_eq!(duration.gt(&Duration::from_secs_f64(1.6 - EPSILON)), true); + assert_eq!(duration.lt(&Duration::from_secs_f64(2.4 + EPSILON)), true); + // 3.2 <= duration <= 4.8. + let duration = backoff.backoff_duration(); + assert_eq!(duration.gt(&Duration::from_secs_f64(3.2 - EPSILON)), true); + assert_eq!(duration.lt(&Duration::from_secs_f64(4.8 + EPSILON)), true); + } +} diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs new file mode 100644 index 000000000..76abe5be0 --- /dev/null +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -0,0 +1,397 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +//! This module implements a DNS resolver to be installed as the default resolver +//! in grpc. + +use std::{ + net::{IpAddr, SocketAddr}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, + }, + time::{Duration, SystemTime}, +}; + +use tokio::sync::Notify; +use url::Host; + +use crate::{ + client::name_resolution::{Address, NopResolver, ResolverUpdate, TCP_IP_NETWORK_TYPE}, + rt, +}; + +use super::{ + backoff::{BackoffConfig, ExponentialBackoff, DEFAULT_EXPONENTIAL_CONFIG}, + global_registry, Endpoint, Resolver, ResolverBuilder, +}; + +#[cfg(test)] +mod test; + +const DEFAULT_PORT: u16 = 443; +const DEFAULT_DNS_PORT: u16 = 53; + +/// This specifies the maximum duration for a DNS resolution request. +/// If the timeout expires before a response is received, the request will be +/// canceled. +/// +/// It is recommended to set this value at application startup. Avoid modifying +/// this variable after initialization. +static RESOLVING_TIMEOUT_MS: AtomicU64 = AtomicU64::new(30_000); // 30 seconds + +/// This is the minimum interval at which re-resolutions are allowed. This helps +/// to prevent excessive re-resolution. +static MIN_RESOLUTION_INTERVAL_MS: AtomicU64 = AtomicU64::new(30_000); // 30 seconds + +fn get_resolving_timeout() -> Duration { + Duration::from_millis(RESOLVING_TIMEOUT_MS.load(Ordering::Relaxed)) +} + +/// Sets the maximum duration for DNS resolution requests. +/// +/// This function affects the global timeout used by all channels using the DNS +/// name resolver scheme. +/// +/// It must be called only at application startup, before any gRPC calls are +/// made. +/// +/// The default value is 30 seconds. Setting the timeout too low may result in +/// premature timeouts during resolution, while setting it too high may lead to +/// unnecessary delays in service discovery. Choose a value appropriate for your +/// specific needs and network environment. +pub fn set_resolving_timeout(duration: Duration) { + RESOLVING_TIMEOUT_MS.store(duration.as_millis() as u64, Ordering::Relaxed); +} + +fn get_min_resolution_interval() -> Duration { + Duration::from_millis(MIN_RESOLUTION_INTERVAL_MS.load(Ordering::Relaxed)) +} + +/// Sets the default minimum interval at which DNS re-resolutions are allowed. +/// This helps to prevent excessive re-resolution. +/// +/// It must be called only at application startup, before any gRPC calls are +/// made. +pub fn set_min_resolution_interval(duration: Duration) { + MIN_RESOLUTION_INTERVAL_MS.store(duration.as_millis() as u64, Ordering::Relaxed); +} + +pub fn reg() { + global_registry().add_builder(Box::new(Builder {})); +} + +struct Builder {} + +struct DnsOptions { + min_resolution_interval: Duration, + resolving_timeout: Duration, + backoff_config: BackoffConfig, +} + +impl DnsResolver { + fn new( + target: &super::Target, + options: super::ResolverOptions, + dns_opts: DnsOptions, + ) -> Box { + let parsed = match parse_endpoint_and_authority(target) { + Ok(res) => res, + Err(err) => return nop_resolver_for_err(err.to_string(), options), + }; + let endpoint = parsed.endpoint; + let host = match endpoint.host { + Host::Domain(d) => d, + Host::Ipv4(ipv4) => { + return nop_resolver_for_ip(IpAddr::V4(ipv4), endpoint.port, options) + } + Host::Ipv6(ipv6) => { + return nop_resolver_for_ip(IpAddr::V6(ipv6), endpoint.port, options) + } + }; + let authority = parsed.authority; + let dns = match options.runtime.get_dns_resolver(rt::ResolverOptions { + server_addr: authority, + }) { + Ok(dns) => dns, + Err(err) => return nop_resolver_for_err(err.to_string(), options), + }; + let state = Arc::new(Mutex::new(InternalState { + addrs: Ok(Vec::new()), + channel_response: None, + })); + let state_copy = state.clone(); + let resolve_now_notify = Arc::new(Notify::new()); + let channel_updated_notify = Arc::new(Notify::new()); + let channel_updated_rx = channel_updated_notify.clone(); + let resolve_now_rx = resolve_now_notify.clone(); + + let handle = options.runtime.clone().spawn(Box::pin(async move { + let mut backoff = ExponentialBackoff::new(dns_opts.backoff_config.clone()) + .expect("default exponential config must be valid"); + let state = state_copy; + let work_scheduler = options.work_scheduler; + loop { + let mut lookup_fut = dns.lookup_host_name(&host); + let mut timeout_fut = options.runtime.sleep(dns_opts.resolving_timeout); + let addrs = tokio::select! { + result = &mut lookup_fut => { + match result { + Ok(ips) => { + let addrs = ips + .into_iter() + .map(|ip| SocketAddr::new(ip, endpoint.port)) + .collect(); + Ok(addrs) + } + Err(err) => Err(err), + } + } + _ = &mut timeout_fut => { + Err("Timed out waiting for DNS resolution".to_string()) + } + }; + { + let mut internal_state = match state.lock() { + Ok(state) => state, + Err(_) => return, + }; + internal_state.addrs = addrs; + } + work_scheduler.schedule_work(); + channel_updated_rx.notified().await; + let channel_response: Option; + { + channel_response = state.lock().unwrap().channel_response.take(); + } + let next_resoltion_time: SystemTime; + if channel_response.is_some() { + next_resoltion_time = SystemTime::now() + .checked_add(backoff.backoff_duration()) + .unwrap(); + } else { + // Success resolving, wait for the next resolve_now. However, + // also wait MIN_RESOLUTION_INTERVAL at the very least to prevent + // constantly re-resolving. + backoff.reset(); + next_resoltion_time = SystemTime::now() + .checked_add(dns_opts.min_resolution_interval) + .unwrap(); + _ = resolve_now_rx.notified().await; + } + // Wait till next resolution time. + let Ok(duration) = next_resoltion_time.duration_since(SystemTime::now()) else { + continue; // Time has already passed. + }; + options.runtime.sleep(duration).await; + } + })); + + Box::new(DnsResolver { + state, + task_handle: handle, + resolve_now_notifier: resolve_now_notify, + channel_update_notifier: channel_updated_notify, + }) + } +} + +impl ResolverBuilder for Builder { + fn build( + &self, + target: &super::Target, + options: super::ResolverOptions, + ) -> Box { + let dns_opts = DnsOptions { + min_resolution_interval: get_min_resolution_interval(), + resolving_timeout: get_resolving_timeout(), + backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + }; + DnsResolver::new(target, options, dns_opts) + } + + fn scheme(&self) -> &'static str { + "dns" + } + + fn is_valid_uri(&self, target: &super::Target) -> bool { + if let Err(err) = parse_endpoint_and_authority(target) { + eprintln!("{}", err); + false + } else { + true + } + } +} + +struct DnsResolver { + state: Arc>, + task_handle: Box, + resolve_now_notifier: Arc, + channel_update_notifier: Arc, +} + +struct InternalState { + addrs: Result, String>, + // Error from the latest call to channel_controller.update(). + channel_response: Option, +} + +impl Resolver for DnsResolver { + fn resolve_now(&mut self) { + _ = self.resolve_now_notifier.notify_one(); + } + + fn work(&mut self, channel_controller: &mut dyn super::ChannelController) { + let mut state = self.state.lock().unwrap(); + let endpoint_result = match &state.addrs { + Ok(addrs) => { + let endpoints: Vec<_> = addrs + .iter() + .map(|a| Endpoint { + addresses: vec![Address { + network_type: TCP_IP_NETWORK_TYPE, + address: a.to_string(), + ..Default::default() + }], + ..Default::default() + }) + .collect(); + Ok(endpoints) + } + Err(err) => Err(err.to_string()), + }; + let update = ResolverUpdate { + endpoints: endpoint_result, + ..Default::default() + }; + let status = channel_controller.update(update); + state.channel_response = status.err(); + _ = self.channel_update_notifier.notify_one(); + } +} + +impl Drop for DnsResolver { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + +#[derive(Eq, PartialEq, Debug)] +struct HostPort { + host: Host, + port: u16, +} + +#[derive(Eq, PartialEq, Debug)] +struct ParseResult { + endpoint: HostPort, + authority: Option, +} + +fn parse_endpoint_and_authority(target: &super::Target) -> Result { + // Parse the endpoint. + let endpoint = target.path(); + let endpoint = endpoint.strip_prefix("/").unwrap_or(endpoint); + let parse_result = parse_host_port(endpoint, DEFAULT_PORT) + .map_err(|err| format!("Failed to parse target {}: {}", target, err))?; + let endpoint = parse_result.ok_or("Received empty endpoint host.".to_string())?; + + // Parse the authority. + let authority = target.authority_host_port(); + if authority.is_empty() { + return Ok(ParseResult { + endpoint, + authority: None, + }); + } + let parse_result = parse_host_port(&authority, DEFAULT_DNS_PORT) + .map_err(|err| format!("Failed to parse DNS authority {}: {}", target, err))?; + let Some(authority) = parse_result else { + return Ok(ParseResult { + endpoint, + authority: None, + }); + }; + let authority = match authority.host { + Host::Ipv4(ipv4) => SocketAddr::new(IpAddr::V4(ipv4), authority.port), + Host::Ipv6(ipv6) => SocketAddr::new(IpAddr::V6(ipv6), authority.port), + _ => { + return Err(format!("Received non-IP DNS authority {}", authority.host)); + } + }; + Ok(ParseResult { + endpoint, + authority: Some(authority), + }) +} + +/// Takes the user input string of the format "host:port" and default port, +/// returns the parsed host and port. If string doesn't specify a port, the +/// default_port is returned. If the string doesn't specify the host, +/// Result is returned. +fn parse_host_port(host_and_port: &str, default_port: u16) -> Result, String> { + // We need to use the https scheme otherwise url::Url::parse doesn't convert + // IP addresses to Host::Ipv4 or Host::Ipv6 if they could represent valid + // domains. + let url = format!("https://{}", host_and_port); + let url = url.parse::().map_err(|err| err.to_string())?; + let port = url.port().unwrap_or(default_port); + let host = match url.host() { + Some(host) => host, + None => return Ok(None), + }; + // Convert the domain to an owned string. + let host = match host { + Host::Domain(s) => Host::Domain(s.to_owned()), + Host::Ipv4(ip) => Host::Ipv4(ip), + Host::Ipv6(ip) => Host::Ipv6(ip), + }; + Ok(Some(HostPort { host, port })) +} + +fn nop_resolver_for_ip( + ip: IpAddr, + port: u16, + options: super::ResolverOptions, +) -> Box { + options.work_scheduler.schedule_work(); + Box::new(NopResolver { + update: ResolverUpdate { + endpoints: Ok(vec![Endpoint { + addresses: vec![Address { + network_type: TCP_IP_NETWORK_TYPE, + address: SocketAddr::new(ip, port).to_string(), + ..Default::default() + }], + ..Default::default() + }]), + ..Default::default() + }, + }) +} + +fn nop_resolver_for_err(err: String, options: super::ResolverOptions) -> Box { + options.work_scheduler.schedule_work(); + Box::new(NopResolver { + update: ResolverUpdate { + endpoints: Err(err), + ..Default::default() + }, + }) +} diff --git a/grpc/src/client/name_resolution/dns/test.rs b/grpc/src/client/name_resolution/dns/test.rs new file mode 100644 index 000000000..5fad96e4a --- /dev/null +++ b/grpc/src/client/name_resolution/dns/test.rs @@ -0,0 +1,505 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use std::{future::Future, pin::Pin, sync::Arc, time::Duration}; + +use tokio::sync::mpsc::{self, UnboundedSender}; + +use crate::{ + client::name_resolution::{ + self, + backoff::{BackoffConfig, DEFAULT_EXPONENTIAL_CONFIG}, + dns::{parse_endpoint_and_authority, HostPort}, + global_registry, ResolverOptions, ResolverUpdate, Target, + }, + rt::{self, tokio::TokioRuntime}, +}; + +use super::ParseResult; + +const DEFAULT_TEST_SHORT_TIMEOUT: Duration = Duration::from_millis(10); + +#[test] +pub fn target_parsing() { + struct TestCase { + input: &'static str, + want_result: Result, + } + let test_cases = vec![ + TestCase { + input: "dns:///grpc.io", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 443, + }, + authority: None, + }), + }, + TestCase { + input: "dns:///grpc.io:1234", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 1234, + }, + authority: None, + }), + }, + TestCase { + input: "dns://8.8.8.8/grpc.io:1234", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 1234, + }, + authority: Some("8.8.8.8:53".parse().unwrap()), + }), + }, + TestCase { + input: "dns://8.8.8.8:5678/grpc.io:1234/abc", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 1234, + }, + authority: Some("8.8.8.8:5678".parse().unwrap()), + }), + }, + TestCase { + input: "dns://[::1]:5678/grpc.io:1234/abc", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 1234, + }, + authority: Some("[::1]:5678".parse().unwrap()), + }), + }, + TestCase { + input: "dns://[fe80::1]:5678/127.0.0.1:1234/abc", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Ipv4("127.0.0.1".parse().unwrap()), + port: 1234, + }, + authority: Some("[fe80::1]:5678".parse().unwrap()), + }), + }, + TestCase { + input: "dns:///[fe80::1%80]:5678/abc", + want_result: Err("SocketAddr doesn't support IPv6 addresses with zones".to_string()), + }, + TestCase { + input: "dns:///:5678/abc", + want_result: Err("Empty host with port".to_string()), + }, + TestCase { + input: "dns:///grpc.io:abc/abc", + want_result: Err("Non numeric port".to_string()), + }, + TestCase { + input: "dns:///grpc.io:/", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 443, + }, + authority: None, + }), + }, + TestCase { + input: "dns:///:", + want_result: Err("No host and port".to_string()), + }, + TestCase { + input: "dns:///[2001:db8:a0b:12f0::1", + want_result: Err("Invalid address".to_string()), + }, + ]; + + for tc in test_cases { + let target: Target = tc.input.parse().unwrap(); + let got = parse_endpoint_and_authority(&target); + if got.is_err() != tc.want_result.is_err() { + panic!( + "Got error {:?}, want error: {:?}", + got.err(), + tc.want_result.err() + ); + } + if got.is_err() { + continue; + } + assert_eq!(got.unwrap(), tc.want_result.unwrap()); + } +} + +struct WorkScheduler { + work_tx: UnboundedSender<()>, +} + +impl name_resolution::WorkScheduler for WorkScheduler { + fn schedule_work(&self) { + self.work_tx.send(()).unwrap(); + } +} + +struct FakeChannelController { + update_result: Result<(), String>, + update_tx: UnboundedSender, +} + +impl name_resolution::ChannelController for FakeChannelController { + fn update(&mut self, update: name_resolution::ResolverUpdate) -> Result<(), String> { + println!("Received resolver update: {:?}", &update); + self.update_tx.send(update).unwrap(); + self.update_result.clone() + } + + fn parse_service_config( + &self, + _: &str, + ) -> Result { + Err("Unimplemented".to_string()) + } +} + +#[tokio::test] +pub async fn dns_basic() { + super::reg(); + let builder = global_registry().get("dns").unwrap(); + let target = &"dns:///localhost:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let mut resolver = builder.build(target, opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // A successful endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); +} + +#[tokio::test] +pub async fn invalid_target() { + super::reg(); + let builder = global_registry().get("dns").unwrap(); + let target = &"dns:///:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let mut resolver = builder.build(target, opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // An error endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!( + update + .endpoints + .err() + .unwrap() + .contains(&target.to_string()), + true + ); +} + +#[derive(Clone)] +struct FakeDns { + latency: Duration, + lookup_result: Result, String>, +} + +#[tonic::async_trait] +impl rt::DnsResolver for FakeDns { + async fn lookup_host_name(&self, _: &str) -> Result, String> { + tokio::time::sleep(self.latency).await; + self.lookup_result.clone() + } + + async fn lookup_txt(&self, _: &str) -> Result, String> { + Err("unimplemented".to_string()) + } +} + +struct FakeRuntime { + inner: TokioRuntime, + dns: FakeDns, +} + +impl rt::Runtime for FakeRuntime { + fn spawn( + &self, + task: Pin + Send + 'static>>, + ) -> Box { + self.inner.spawn(task) + } + + fn get_dns_resolver(&self, _: rt::ResolverOptions) -> Result, String> { + Ok(Box::new(self.dns.clone())) + } + + fn sleep(&self, duration: std::time::Duration) -> Pin> { + self.inner.sleep(duration) + } +} + +#[tokio::test] +pub async fn dns_lookup_error() { + super::reg(); + let builder = global_registry().get("dns").unwrap(); + let target = &"dns:///grpc.io:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let runtime = FakeRuntime { + inner: TokioRuntime {}, + dns: FakeDns { + latency: Duration::from_secs(0), + lookup_result: Err("test_error".to_string()), + }, + }; + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(runtime), + work_scheduler: work_scheduler.clone(), + }; + let mut resolver = builder.build(target, opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // An error endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.err().unwrap().contains("test_error"), true); +} + +#[tokio::test] +pub async fn dns_lookup_timeout() { + let target = &"dns:///grpc.io:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let runtime = FakeRuntime { + inner: TokioRuntime {}, + dns: FakeDns { + latency: Duration::from_secs(20), + lookup_result: Ok(Vec::new()), + }, + }; + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(runtime), + work_scheduler: work_scheduler.clone(), + }; + let dns_opts = super::DnsOptions { + min_resolution_interval: super::get_min_resolution_interval(), + resolving_timeout: DEFAULT_TEST_SHORT_TIMEOUT, + backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + }; + let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + + // An error endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.err().unwrap().contains("Timed out"), true); +} + +#[tokio::test] +pub async fn rate_limit() { + let target = &"dns:///localhost:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let dns_opts = super::DnsOptions { + min_resolution_interval: Duration::from_secs(20), + resolving_timeout: super::get_resolving_timeout(), + backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + }; + let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + + // Wait for schedule work to be called. + let event = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // A successful endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); + + // Call resolve_now repeatedly, new updates should not be produced. + for _ in 0..5 { + resolver.resolve_now(); + tokio::select! { + _ = work_rx.recv() => { + panic!("Received unexpected work request from resolver: {:?}", event); + } + _ = tokio::time::sleep(DEFAULT_TEST_SHORT_TIMEOUT) => { + println!("No work requested from resolver."); + } + }; + } +} + +#[tokio::test] +pub async fn re_resolution_after_success() { + let target = &"dns:///localhost:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let dns_opts = super::DnsOptions { + min_resolution_interval: Duration::from_millis(1), + resolving_timeout: super::get_resolving_timeout(), + backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + }; + let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // A successful endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); + + // Call resolve_now, a new update should be produced. + resolver.resolve_now(); + let _ = work_rx.recv().await.unwrap(); + resolver.work(&mut channel_controller); + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); +} + +#[tokio::test] +pub async fn backoff_on_error() { + let target = &"dns:///localhost:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let dns_opts = super::DnsOptions { + min_resolution_interval: Duration::from_millis(1), + resolving_timeout: super::get_resolving_timeout(), + // Speed up the backoffs to make the test run faster. + backoff_config: BackoffConfig { + base_delay: Duration::from_millis(1), + multiplier: 1.0, + jitter: 0.0, + max_delay: Duration::from_millis(1), + }, + }; + let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Err("test_error".to_string()), + }; + + // As the channel returned an error to the resolver, the resolver will + // backoff and re-attempt resolution. + for _ in 0..5 { + let _ = work_rx.recv().await.unwrap(); + resolver.work(&mut channel_controller); + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); + } + + // This time the channel accepts the resolver update. + channel_controller.update_result = Ok(()); + let _ = work_rx.recv().await.unwrap(); + resolver.work(&mut channel_controller); + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); + + // Since the channel controller returns Ok(), the resolver will stop + // producing more updates. + tokio::select! { + _ = work_rx.recv() => { + panic!("Received unexpected work request from resolver."); + } + _ = tokio::time::sleep(DEFAULT_TEST_SHORT_TIMEOUT) => { + println!("No event received from resolver."); + } + }; +} diff --git a/grpc/src/client/name_resolution/mod.rs b/grpc/src/client/name_resolution/mod.rs index 05c5cf795..5c808250c 100644 --- a/grpc/src/client/name_resolution/mod.rs +++ b/grpc/src/client/name_resolution/mod.rs @@ -23,95 +23,213 @@ //! a service. use core::fmt; +use super::service_config::ServiceConfig; +use crate::{attributes::Attributes, rt}; use std::{ - error::Error, fmt::{Display, Formatter}, hash::Hash, + str::FromStr, sync::Arc, }; -use tokio::sync::Notify; -use tonic::async_trait; -use url::Url; +mod backoff; +mod dns; +mod registry; +pub use registry::global_registry; -use crate::attributes::Attributes; +/// Target represents a target for gRPC, as specified in: +/// https://github.com/grpc/grpc/blob/master/doc/naming.md. +/// It is parsed from the target string that gets passed during channel creation +/// by the user. gRPC passes it to the resolver and the balancer. +/// +/// If the target follows the naming spec, and the parsed scheme is registered +/// with gRPC, we will parse the target string according to the spec. If the +/// target does not contain a scheme or if the parsed scheme is not registered +/// (i.e. no corresponding resolver available to resolve the endpoint), we will +/// apply the default scheme, and will attempt to reparse it. +#[derive(Debug, Clone)] +pub struct Target { + url: url::Url, +} -use super::service_config::ServiceConfig; +impl FromStr for Target { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.parse::() { + Ok(url) => Ok(Target { url }), + Err(err) => Err(err.to_string()), + } + } +} + +impl Target { + pub fn scheme(&self) -> &str { + self.url.scheme() + } + + /// The host part of the authority. + pub fn authority_host(&self) -> &str { + self.url.host_str().unwrap_or("") + } + + /// The port part of the authority. + pub fn authority_port(&self) -> Option { + self.url.port() + } + + /// Returns either host:port or host depending on the existence of the port + /// in the authority. + pub fn authority_host_port(&self) -> String { + let host = self.authority_host(); + let port = self.authority_port(); + if let Some(port) = port { + format!("{}:{}", host, port) + } else { + host.to_owned() + } + } + + /// Return the path for this target URL, as a percent-encoded ASCII string. + pub fn path(&self) -> &str { + self.url.path() + } +} + +impl Display for Target { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}://{}{}", + self.scheme(), + self.authority_host_port(), + self.path() + ) + } +} /// A name resolver factory that produces Resolver instances used by the channel /// to resolve network addresses for the target URI. pub trait ResolverBuilder: Send + Sync { - /// Builds and returns a new name resolver instance. + /// Builds a name resolver instance. /// /// Note that build must not fail. Instead, an erroring Resolver may be /// returned that calls ChannelController.update() with an Err value. - fn build( - &self, - target: Url, - resolve_now: Arc, - options: ResolverOptions, - ) -> Box; + fn build(&self, target: &Target, options: ResolverOptions) -> Box; /// Reports the URI scheme handled by this name resolver. fn scheme(&self) -> &'static str; - /// Returns the default authority for a channel using this name resolver and - /// target. This is typically the same as the service's name. By default, - /// the default_authority method automatically returns the path portion of - /// the target URI, with the leading prefix removed. - fn default_authority(&self, target: &Url) -> String { + /// Returns the default authority for a channel using this name resolver + /// and target. This is typically the same as the service's name. By + /// default, the default_authority method automatically returns the path + /// portion of the target URI, with the leading prefix removed. + fn default_authority(&self, target: &Target) -> String { let path = target.path(); path.strip_prefix("/").unwrap_or(path).to_string() } + + /// Returns a bool indicating whether the input uri is valid to create a + /// resolver. + fn is_valid_uri(&self, uri: &Target) -> bool; } /// A collection of data configured on the channel that is constructing this /// name resolver. -#[derive(Debug, Default)] #[non_exhaustive] pub struct ResolverOptions { /// The authority that will be used for the channel by default. This /// contains either the result of the default_authority method of this /// ResolverBuilder, or another string if the channel was configured to /// override the default. - authority: String, + pub authority: String, + + /// The runtime which provides utilities to do async work. + pub runtime: Arc, + + /// A hook into the channel's work scheduler that allows the Resolver to + /// request the ability to perform operations on the ChannelController. + pub work_scheduler: Arc, } -#[async_trait] -/// A collection of operations a Resolver may perform on the channel which -/// constructed it. -pub trait ChannelController: Send + Sync { - /// Parses the provided JSON service config. - fn parse_config(&self, config: &str) -> Result>; // TODO +/// Used to asynchronously request a call into the Resolver's work method. +pub trait WorkScheduler: Send + Sync { + // Schedules a call into the Resolver's work method. If there is already a + // pending work call that has not yet started, this may not schedule another + // call. + fn schedule_work(&self); +} + +/// Resolver watches for the updates on the specified target. +/// Updates include address updates and service config updates. +pub trait Resolver: Send { + /// Asks the resolver to obtain an updated resolver result, if + /// applicable. + /// + /// This is useful for polling resolvers to decide when to re-resolve. + /// However, the implementation is not required to + /// re-resolve immediately upon receiving this call; it may instead + /// elect to delay based on some configured minimum time between + /// queries, to avoid hammering the name service with queries. + /// + /// For watch based resolvers, this may be a no-op. + fn resolve_now(&mut self); + /// Called serially by the channel to to allow access to ChannelController. + fn work(&mut self, channel_controller: &mut dyn ChannelController); +} + +/// The `ChannelController` trait provides the resolver with functionality +/// to interact with the channel. +pub trait ChannelController: Send + Sync { /// Notifies the channel about the current state of the name resolver. If /// an error value is returned, the name resolver should attempt to /// re-resolve, if possible. The resolver is responsible for applying an /// appropriate backoff mechanism to avoid overloading the system or the /// remote resolver. - async fn update(&self, update: ResolverUpdate) -> Result<(), Box>; -} + fn update(&mut self, update: ResolverUpdate) -> Result<(), String>; -/// A name resolver update expresses the current state of the resolver. -pub enum ResolverUpdate { - /// Indicates the name resolver encountered an error. - Err(Box), - /// Indicates the name resolver produced a valid result. - Data(ResolverData), + /// Parses the provided JSON service config and returns an instance of a + /// ParsedServiceConfig. + fn parse_service_config(&self, config: &str) -> Result; } -/// Data provided by the name resolver to the channel. -#[derive(Debug, Default)] +#[derive(Clone, Debug)] #[non_exhaustive] -pub struct ResolverData { +/// ResolverUpdate contains the current Resolver state relevant to the +/// channel. +pub struct ResolverUpdate { + /// Attributes contains arbitrary data about the resolver intended for + /// consumption by the load balancing policy. + pub attributes: Arc, + /// A list of endpoints which each identify a logical host serving the /// service indicated by the target URI. - pub endpoints: Vec, + pub endpoints: Result, String>, + /// The service config which the client should use for communicating with - /// the service. - pub service_config: Option, - // Optional data which may be used by the LB Policy or channel. - pub attributes: Attributes, + /// the service. If it is None, it indicates no service config is present or + /// the resolver does not provide service configs. + pub service_config: Result, String>, + + /// An optional human-readable note describing context about the + /// resolution, to be passed along to the LB policy for inclusion in + /// RPC failure status messages in cases where neither endpoints nor + /// service_config has a non-OK status. For example, a resolver that + /// returns an empty endpoint list but a valid service config may set + /// to this to something like "no DNS entries found for ". + pub resolution_note: Option, +} + +impl Default for ResolverUpdate { + fn default() -> Self { + ResolverUpdate { + service_config: Ok(Default::default()), + attributes: Default::default(), + endpoints: Ok(Default::default()), + resolution_note: Default::default(), + } + } } /// An Endpoint is an address or a collection of addresses which reference one @@ -120,9 +238,28 @@ pub struct ResolverData { #[derive(Debug, Default, Clone)] #[non_exhaustive] pub struct Endpoint { - /// The list of addresses used to connect to the server. + /// Addresses contains a list of addresses used to access this endpoint. pub addresses: Vec
, - /// Optional data which may be used by the LB policy or channel. + + /// Attributes contains arbitrary data about this endpoint intended for + /// consumption by the LB policy. + pub attributes: Attributes, +} + +/// An Address is an identifier that indicates how to connect to a server. +#[non_exhaustive] +#[derive(Debug, Clone, Default)] +pub struct Address { + /// The network type is used to identify what kind of transport to create + /// when connecting to this address. Typically TCP_IP_ADDRESS_TYPE. + pub network_type: &'static str, + + /// The address itself is passed to the transport in order to create a + /// connection to it. + pub address: String, + + /// Attributes contains arbitrary data about this address intended for + /// consumption by the subchannel. pub attributes: Attributes, } @@ -140,20 +277,6 @@ impl Hash for Endpoint { } } -/// An Address is an identifier that indicates how to connect to a server. -#[derive(Debug, Default, Clone)] -#[non_exhaustive] -pub struct Address { - /// The address type is used to identify what kind of transport to create - /// when connecting to this address. Typically TCP_IP_ADDRESS_TYPE. - pub address_type: String, // TODO: &'static str? - /// The address itself is passed to the transport in order to create a - /// connection to it. - pub address: String, - // Optional data which the transport may use for the connection. - pub attributes: Attributes, -} - impl Eq for Address {} impl PartialEq for Address { @@ -170,20 +293,92 @@ impl Hash for Address { impl Display for Address { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}:{}", self.address_type, self.address) + write!(f, "{}:{}", self.network_type, self.address) } } /// Indicates the address is an IPv4 or IPv6 address that should be connected to /// via TCP/IP. -pub static TCP_IP_ADDRESS_TYPE: &str = "tcp"; - -/// A name resolver instance. -#[async_trait] -pub trait Resolver: Send + Sync { - /// The entry point of the resolver. Will only be called once by the - /// channel. Should not return unless the resolver never will need to - /// update its state. The future will be dropped when the channel shuts - /// down or enters idle mode. - async fn run(&mut self, channel_controller: Box); +pub static TCP_IP_NETWORK_TYPE: &str = "tcp"; + +// A resolver that returns the same result every time it's work method is called. +// It can be used to return an error to the channel when a resolver fails to +// build. +struct NopResolver { + pub update: ResolverUpdate, +} + +impl Resolver for NopResolver { + fn resolve_now(&mut self) {} + + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + let _ = channel_controller.update(self.update.clone()); + } +} + +#[cfg(test)] +mod test { + use super::Target; + + #[test] + pub fn parse_target() { + #[derive(Default)] + struct TestCase { + input: &'static str, + want_scheme: &'static str, + want_host: &'static str, + want_port: Option, + want_host_port: &'static str, + want_path: &'static str, + want_str: &'static str, + } + let test_cases = vec![ + TestCase { + input: "dns:///grpc.io", + want_scheme: "dns", + want_host_port: "", + want_host: "", + want_port: None, + want_path: "/grpc.io", + want_str: "dns:///grpc.io", + }, + TestCase { + input: "dns://8.8.8.8:53/grpc.io/docs", + want_scheme: "dns", + want_host_port: "8.8.8.8:53", + want_host: "8.8.8.8", + want_port: Some(53), + want_path: "/grpc.io/docs", + want_str: "dns://8.8.8.8:53/grpc.io/docs", + }, + TestCase { + input: "unix:path/to/file", + want_scheme: "unix", + want_host_port: "", + want_host: "", + want_port: None, + want_path: "path/to/file", + want_str: "unix://path/to/file", + }, + TestCase { + input: "unix:///run/containerd/containerd.sock", + want_scheme: "unix", + want_host_port: "", + want_host: "", + want_port: None, + want_path: "/run/containerd/containerd.sock", + want_str: "unix:///run/containerd/containerd.sock", + }, + ]; + + for tc in test_cases { + let target: Target = tc.input.parse().unwrap(); + assert_eq!(target.scheme(), tc.want_scheme); + assert_eq!(target.authority_host(), tc.want_host); + assert_eq!(target.authority_port(), tc.want_port); + assert_eq!(target.authority_host_port(), tc.want_host_port); + assert_eq!(target.path(), tc.want_path); + assert_eq!(&target.to_string(), tc.want_str); + } + } } diff --git a/grpc/src/client/name_resolution/registry.rs b/grpc/src/client/name_resolution/registry.rs new file mode 100644 index 000000000..fa4a42163 --- /dev/null +++ b/grpc/src/client/name_resolution/registry.rs @@ -0,0 +1,73 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use std::{ + collections::HashMap, + sync::{Arc, Mutex, OnceLock}, +}; + +use super::ResolverBuilder; + +static GLOBAL_RESOLVER_REGISTRY: OnceLock = OnceLock::new(); + +/// A registry to store and retrieve name resolvers. Resolvers are indexed by +/// the URI scheme they are intended to handle. +#[derive(Default)] +pub struct ResolverRegistry { + m: Arc>>>, +} + +impl ResolverRegistry { + /// Construct an empty name resolver registry. + fn new() -> Self { + Self { m: Arc::default() } + } + + /// Add a name resolver into the registry. builder.scheme() will + // be used as the scheme registered with this builder. If multiple + // resolvers are registered with the same name, the one registered last + // will take effect. Panics if the given scheme contains uppercase + // characters. + pub fn add_builder(&self, builder: Box) { + let scheme = builder.scheme(); + if scheme.chars().any(|c| c.is_ascii_uppercase()) { + panic!("Scheme must not contain uppercase characters: {}", scheme); + } + self.m + .lock() + .unwrap() + .insert(scheme.to_string(), Arc::from(builder)); + } + + /// Returns the resolver builder registered for the given scheme, if any. + /// + /// The provided scheme is case-insensitive; any uppercase characters + /// will be converted to lowercase before lookup. + pub fn get(&self, scheme: &str) -> Option> { + self.m + .lock() + .unwrap() + .get(&scheme.to_lowercase()) + .map(|b| b.clone()) + } +} + +/// Global registry for resolver builders. +pub fn global_registry() -> &'static ResolverRegistry { + GLOBAL_RESOLVER_REGISTRY.get_or_init(|| ResolverRegistry::new()) +} diff --git a/grpc/src/client/service_config.rs b/grpc/src/client/service_config.rs index 3639c03e5..1d5ad153a 100644 --- a/grpc/src/client/service_config.rs +++ b/grpc/src/client/service_config.rs @@ -18,5 +18,5 @@ /// An in-memory representation of a service config, usually provided to gRPC as /// a JSON object. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub(crate) struct ServiceConfig; diff --git a/grpc/src/lib.rs b/grpc/src/lib.rs index 244a7a904..a63f4d323 100644 --- a/grpc/src/lib.rs +++ b/grpc/src/lib.rs @@ -27,6 +27,7 @@ #![allow(dead_code)] pub mod client; +mod rt; pub mod service; pub(crate) mod attributes; diff --git a/grpc/src/rt/mod.rs b/grpc/src/rt/mod.rs new file mode 100644 index 000000000..e7b0405a1 --- /dev/null +++ b/grpc/src/rt/mod.rs @@ -0,0 +1,69 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use std::{future::Future, pin::Pin}; + +pub mod tokio; + +/// An abstraction over an asynchronous runtime. +/// +/// The `Runtime` trait defines the core functionality required for +/// executing asynchronous tasks, creating DNS resolvers, and performing +/// time-based operations such as sleeping. It provides a uniform interface +/// that can be implemented for various async runtimes, enabling pluggable +/// and testable infrastructure. +pub trait Runtime: Send + Sync { + /// Spawns the given asynchronous task to run in the background. + fn spawn( + &self, + task: Pin + Send + 'static>>, + ) -> Box; + + /// Creates and returns an instance of a DNSResolver, optionally + /// configured by the ResolverOptions struct. This method may return an + /// error if it fails to create the DNSResolver. + fn get_dns_resolver(&self, opts: ResolverOptions) -> Result, String>; + + /// Returns a future that completes after the specified duration. + fn sleep(&self, duration: std::time::Duration) -> Pin>; +} + +/// A future that resolves after a specified duration. +pub trait Sleep: Send + Sync + Future {} + +pub trait TaskHandle: Send + Sync { + /// Abort the associated task. + fn abort(&self); +} + +/// A trait for asynchronous DNS resolution. +#[tonic::async_trait] +pub trait DnsResolver: Send + Sync { + /// Resolve an address + async fn lookup_host_name(&self, name: &str) -> Result, String>; + /// Perform a TXT record lookup. If a txt record contains multiple strings, + /// they are concatenated. + async fn lookup_txt(&self, name: &str) -> Result, String>; +} + +#[derive(Default)] +pub struct ResolverOptions { + /// The address of the DNS server in "IP:port" format. If None, the + /// system's default DNS server will be used. + pub server_addr: Option, +} diff --git a/grpc/src/rt/tokio/hickory_resolver.rs b/grpc/src/rt/tokio/hickory_resolver.rs new file mode 100644 index 000000000..13b20b010 --- /dev/null +++ b/grpc/src/rt/tokio/hickory_resolver.rs @@ -0,0 +1,234 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}; + +/// A DNS resolver that uses hickory with the tokio runtime. This supports txt +/// lookups in addition to A and AAAA record lookups. It also supports using +/// custom DNS servers. +pub struct DnsResolver { + resolver: hickory_resolver::TokioResolver, +} + +#[tonic::async_trait] +impl super::DnsResolver for DnsResolver { + async fn lookup_host_name(&self, name: &str) -> Result, String> { + let response = self + .resolver + .lookup_ip(name) + .await + .map_err(|err| err.to_string())?; + Ok(response.iter().collect()) + } + + async fn lookup_txt(&self, name: &str) -> Result, String> { + let response: Vec<_> = self + .resolver + .txt_lookup(name) + .await + .map_err(|err| err.to_string())? + .iter() + .map(|txt_record| { + txt_record + .iter() + .map(|bytes| String::from_utf8_lossy(bytes).into_owned()) + .collect::>() + .join("") + }) + .collect(); + Ok(response) + } +} + +impl DnsResolver { + pub fn new(opts: super::ResolverOptions) -> Result { + let builder = if let Some(server_addr) = opts.server_addr { + let provider = hickory_resolver::name_server::TokioConnectionProvider::default(); + let name_servers = NameServerConfigGroup::from_ips_clear( + &[server_addr.ip()], + server_addr.port(), + true, + ); + let config = ResolverConfig::from_parts(None, vec![], name_servers); + hickory_resolver::TokioResolver::builder_with_config(config, provider) + } else { + hickory_resolver::TokioResolver::builder_tokio().map_err(|err| err.to_string())? + }; + let mut resolver_opts = ResolverOpts::default(); + resolver_opts.ip_strategy = hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6; + Ok(DnsResolver { + resolver: builder.with_options(resolver_opts).build(), + }) + } +} + +#[cfg(test)] +mod tests { + use std::{ + net::{Ipv4Addr, SocketAddr}, + sync::Arc, + }; + + use hickory_resolver::Name; + use hickory_server::{ + authority::{Catalog, ZoneType}, + proto::rr::{ + rdata::{A, TXT}, + LowerName, RData, Record, + }, + store::in_memory::InMemoryAuthority, + ServerFuture, + }; + use tokio::{net::UdpSocket, sync::oneshot, task::JoinHandle}; + + use crate::rt::{tokio::TokioDefaultDnsResolver, DnsResolver, ResolverOptions}; + + #[tokio::test] + async fn compare_hickory_and_default() { + let hickory_dns = super::DnsResolver::new(ResolverOptions::default()).unwrap(); + let mut ips_hickory = hickory_dns.lookup_host_name("localhost").await.unwrap(); + + let default_resolver = TokioDefaultDnsResolver::new(ResolverOptions::default()).unwrap(); + + let mut system_resolver_ips = default_resolver + .lookup_host_name("localhost") + .await + .unwrap(); + + // Hickory requests A and AAAA records in parallel, so the order of IPv4 + // and IPv6 addresses isn't deterministic. + ips_hickory.sort(); + system_resolver_ips.sort(); + assert_eq!( + ips_hickory, system_resolver_ips, + "both resolvers should produce same IPs for localhost" + ) + } + + #[tokio::test] + async fn resolve_txt() { + let records = vec![ + Record::from_rdata( + Name::from_ascii("test.local.").unwrap(), + 300, + RData::TXT(TXT::new(vec![ + "one".to_string(), + "two".to_string(), + "three".to_string(), + ])), + ), + Record::from_rdata( + Name::from_ascii("test.local.").unwrap(), + 300, + RData::TXT(TXT::new(vec![ + "abc".to_string(), + "def".to_string(), + "ghi".to_string(), + ])), + ), + ]; + + let dns = start_in_memory_dns_server("test.local.", records).await; + let opts = ResolverOptions { + server_addr: Some(dns.addr), + }; + let hickory_dns = super::DnsResolver::new(opts).unwrap(); + + let txt = hickory_dns.lookup_txt("test.local").await.unwrap(); + assert_eq!( + txt, + vec!["onetwothree".to_string(), "abcdefghi".to_string(),] + ); + dns.shutdown().await; + } + + #[tokio::test] + async fn custom_authority() { + let record = Record::from_rdata( + Name::from_ascii("test.local.").unwrap(), + 300, + RData::A(A(Ipv4Addr::new(1, 2, 3, 4))), + ); + let dns = start_in_memory_dns_server("test.local.", vec![record]).await; + let opts = ResolverOptions { + server_addr: Some(dns.addr), + }; + let hickory_dns = super::DnsResolver::new(opts).unwrap(); + let ips = hickory_dns.lookup_host_name("test.local").await.unwrap(); + assert_eq!(ips, vec![Ipv4Addr::new(1, 2, 3, 4)]); + dns.shutdown().await + } + + struct FakeDns { + tx: Option>, + join_handle: Option>, + addr: SocketAddr, + } + + impl FakeDns { + async fn shutdown(mut self) { + let tx = self.tx.take().unwrap(); + tx.send(()).unwrap(); + let handle = self.join_handle.take().unwrap(); + handle.await.unwrap(); + } + } + + /// Starts an in-memory DNS server with and adds the given records. Returns + /// a DNS server which should be shutdown after the test. It uses a random + /// port to bind since tests can run in parallel. The assigned port can be + /// read from the returned struct. + async fn start_in_memory_dns_server(host: &str, records: Vec) -> FakeDns { + // Create a simple A record for `test.local.` + let authority = + InMemoryAuthority::empty(Name::from_ascii(host).unwrap(), ZoneType::Primary, false); + + for record in records { + authority.upsert(record, 0).await; + } + + let mut catalog = Catalog::new(); + catalog.upsert( + LowerName::new(&Name::from_ascii(host).unwrap()), + vec![Arc::new(authority)], + ); + + let mut server = ServerFuture::new(catalog); + + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + server.register_socket(socket); + + println!("DNS server running on {}", addr); + + let (tx, rx) = oneshot::channel::<()>(); + let server_task = tokio::spawn(async move { + tokio::select! { + _ = server.block_until_done() => {}, + _ = rx => { + server.shutdown_gracefully().await.unwrap(); + } + } + }); + FakeDns { + tx: Some(tx), + join_handle: Some(server_task), + addr, + } + } +} diff --git a/grpc/src/rt/tokio/mod.rs b/grpc/src/rt/tokio/mod.rs new file mode 100644 index 000000000..fbebd6f78 --- /dev/null +++ b/grpc/src/rt/tokio/mod.rs @@ -0,0 +1,127 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use std::{future::Future, net::SocketAddr, pin::Pin}; + +use super::{DnsResolver, ResolverOptions, Runtime, Sleep, TaskHandle}; + +#[cfg(feature = "hickory_dns")] +mod hickory_resolver; + +/// A DNS resolver that uses tokio::net::lookup_host for resolution. It only +/// supports host lookups. +pub struct TokioDefaultDnsResolver {} + +#[tonic::async_trait] +impl DnsResolver for TokioDefaultDnsResolver { + async fn lookup_host_name(&self, name: &str) -> Result, String> { + let name_with_port = match name.parse::() { + Ok(ip) => SocketAddr::new(ip, 0).to_string(), + Err(_) => format!("{}:0", name), + }; + let ips = tokio::net::lookup_host(name_with_port) + .await + .map_err(|err| err.to_string())? + .map(|socket_addr| socket_addr.ip()) + .collect(); + Ok(ips) + } + + async fn lookup_txt(&self, _name: &str) -> Result, String> { + Err("TXT record lookup unavailable. Enable the optional 'hickory_dns' feature to enable service config lookups.".to_string()) + } +} + +pub struct TokioRuntime {} + +impl TaskHandle for tokio::task::JoinHandle<()> { + fn abort(&self) { + self.abort() + } +} + +impl Sleep for tokio::time::Sleep {} + +impl Runtime for TokioRuntime { + fn spawn( + &self, + task: Pin + Send + 'static>>, + ) -> Box { + Box::new(tokio::spawn(task)) + } + + fn get_dns_resolver(&self, opts: ResolverOptions) -> Result, String> { + #[cfg(feature = "hickory_dns")] + { + Ok(Box::new(hickory_resolver::DnsResolver::new(opts)?)) + } + #[cfg(not(feature = "hickory_dns"))] + { + Ok(Box::new(TokioDefaultDnsResolver::new(opts)?)) + } + } + + fn sleep(&self, duration: std::time::Duration) -> Pin> { + Box::pin(tokio::time::sleep(duration)) + } +} + +impl TokioDefaultDnsResolver { + pub fn new(opts: ResolverOptions) -> Result { + if opts.server_addr.is_some() { + return Err("Custom DNS server are not supported, enable optional feature 'hickory_dns' to enable support.".to_string()); + } + Ok(TokioDefaultDnsResolver {}) + } +} + +#[cfg(test)] +mod tests { + use super::{DnsResolver, ResolverOptions, Runtime, TokioDefaultDnsResolver, TokioRuntime}; + + #[tokio::test] + async fn lookup_hostname() { + let runtime = TokioRuntime {}; + + let dns = runtime + .get_dns_resolver(ResolverOptions::default()) + .unwrap(); + let ips = dns.lookup_host_name("localhost").await.unwrap(); + assert!( + !ips.is_empty(), + "Expect localhost to resolve to more than 1 IPs." + ) + } + + #[tokio::test] + async fn default_resolver_txt_fails() { + let default_resolver = TokioDefaultDnsResolver::new(ResolverOptions::default()).unwrap(); + + let txt = default_resolver.lookup_txt("google.com").await; + assert!(txt.is_err()) + } + + #[tokio::test] + async fn default_resolver_custom_authority() { + let opts = ResolverOptions { + server_addr: Some("8.8.8.8:53".parse().unwrap()), + }; + let default_resolver = TokioDefaultDnsResolver::new(opts); + assert!(default_resolver.is_err()) + } +}