From 07eb017de2fd211dc5c2b186394ba949e70ab927 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 22 May 2025 19:23:49 +0530 Subject: [PATCH 1/7] Add name resolution API --- grpc/Cargo.toml | 12 +- grpc/src/client/mod.rs | 2 +- grpc/src/client/name_resolution/backoff.rs | 225 +++++++++ grpc/src/client/name_resolution/dns/mod.rs | 380 +++++++++++++++ grpc/src/client/name_resolution/dns/test.rs | 505 ++++++++++++++++++++ grpc/src/client/name_resolution/mod.rs | 328 ++++++++++--- grpc/src/client/name_resolution/registry.rs | 70 +++ grpc/src/client/service_config.rs | 2 +- grpc/src/lib.rs | 1 + grpc/src/rt/mod.rs | 69 +++ grpc/src/rt/tokio/hickory_resolver.rs | 234 +++++++++ grpc/src/rt/tokio/mod.rs | 127 +++++ 12 files changed, 1883 insertions(+), 72 deletions(-) create mode 100644 grpc/src/client/name_resolution/backoff.rs create mode 100644 grpc/src/client/name_resolution/dns/mod.rs create mode 100644 grpc/src/client/name_resolution/dns/test.rs create mode 100644 grpc/src/client/name_resolution/registry.rs create mode 100644 grpc/src/rt/mod.rs create mode 100644 grpc/src/rt/tokio/hickory_resolver.rs create mode 100644 grpc/src/rt/tokio/mod.rs diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index a82569bca..017e6e37f 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -7,6 +7,16 @@ license = "Apache-2.0" [dependencies] url = "2.5.0" -tokio = { version = "1.37.0", features = ["sync"] } +tokio = { version = "1.37.0", features = ["sync", "rt", "net", "time", "macros"] } tonic = { version = "0.13.0", path = "../tonic", default-features = false, features = ["codegen"] } futures-core = "0.3.31" +once_cell = "1.19.0" +hickory-resolver = { version = "0.25.1", optional = true } +rand = "0.8.5" + +[dev-dependencies] +hickory-server = "0.25.2" + +[features] +default = ["hickory_dns"] +hickory_dns = ["dep:hickory-resolver"] diff --git a/grpc/src/client/mod.rs b/grpc/src/client/mod.rs index c9e2365ce..4108411b3 100644 --- a/grpc/src/client/mod.rs +++ b/grpc/src/client/mod.rs @@ 
-19,9 +19,9 @@ use std::fmt::Display; pub mod channel; -pub mod service; pub(crate) mod load_balancing; pub(crate) mod name_resolution; +pub mod service; pub mod service_config; /// A representation of the current state of a gRPC channel, also used for the diff --git a/grpc/src/client/name_resolution/backoff.rs b/grpc/src/client/name_resolution/backoff.rs new file mode 100644 index 000000000..a0709ef8f --- /dev/null +++ b/grpc/src/client/name_resolution/backoff.rs @@ -0,0 +1,225 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use rand::Rng; +use std::{sync::Mutex, time::Duration}; + +#[derive(Clone)] +pub struct BackoffConfig { + /// The amount of time to backoff after the first failure. + pub base_delay: Duration, + + /// The factor with which to multiply backoffs after a + /// failed retry. Should ideally be greater than 1. + pub multiplier: f64, + + /// The factor with which backoffs are randomized. + pub jitter: f64, + + /// The upper bound of backoff delay. + pub max_delay: Duration, +} + +pub struct ExponentialBackoff { + config: BackoffConfig, + + /// The delay for the next retry, without the random jitter. Store as f64 + /// to avoid rounding errors. + next_delay_secs: Mutex, +} + +/// This is a backoff configuration with the default values specified +/// at https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md. 
+/// +/// This should be useful for callers who want to configure backoff with +/// non-default values only for a subset of the options. +pub const DEFAULT_EXPONENTIAL_CONFIG: BackoffConfig = BackoffConfig { + base_delay: Duration::from_secs(1), + multiplier: 1.6, + jitter: 0.2, + max_delay: Duration::from_secs(120), +}; + +impl ExponentialBackoff { + pub fn new(mut config: BackoffConfig) -> Self { + // Adjust params to get them in valid ranges. + // 0 <= base_delay <= max_delay + config.base_delay = config.base_delay.min(config.max_delay); + // 1 <= multiplier + config.multiplier = config.multiplier.max(1.0); + // 0 <= jitter <= 1 + config.jitter = config.jitter.max(0.0); + config.jitter = config.jitter.min(1.0); + let next_delay_secs = config.base_delay.as_secs_f64(); + ExponentialBackoff { + config, + next_delay_secs: Mutex::new(next_delay_secs), + } + } + + pub fn reset(&self) { + let mut next_delay = self.next_delay_secs.lock().unwrap(); + *next_delay = self.config.base_delay.as_secs_f64(); + } + + pub fn backoff_duration(&self) -> Duration { + let mut next_delay = self.next_delay_secs.lock().unwrap(); + let cur_delay = + *next_delay * (1.0 + self.config.jitter * rand::thread_rng().gen_range(-1.0..1.0)); + *next_delay = self + .config + .max_delay + .as_secs_f64() + .min(*next_delay * self.config.multiplier); + Duration::from_secs_f64(cur_delay) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use crate::client::name_resolution::backoff::{BackoffConfig, ExponentialBackoff}; + + // Epsilon for floating point comparisons if needed, though Duration + // comparisons are often better. 
+ const EPSILON: f64 = 1e-9; + + #[test] + fn base_less_than_max() { + let config = BackoffConfig { + base_delay: Duration::from_secs(10), + multiplier: 123.0, + jitter: 0.0, + max_delay: Duration::from_secs(100), + }; + let backoff = ExponentialBackoff::new(config.clone()); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + } + + #[test] + fn base_more_than_max() { + let config = BackoffConfig { + multiplier: 123.0, + jitter: 0.0, + base_delay: Duration::from_secs(100), + max_delay: Duration::from_secs(10), + }; + let backoff = ExponentialBackoff::new(config.clone()); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + } + + #[test] + fn negative_multiplier() { + let config = BackoffConfig { + multiplier: -123.0, + jitter: 0.0, + base_delay: Duration::from_secs(10), + max_delay: Duration::from_secs(100), + }; + let backoff = ExponentialBackoff::new(config.clone()); + // multiplier gets clipped to 1. + assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + } + + #[test] + fn negative_jitter() { + let config = BackoffConfig { + multiplier: 1.0, + jitter: -10.0, + base_delay: Duration::from_secs(10), + max_delay: Duration::from_secs(100), + }; + let backoff = ExponentialBackoff::new(config.clone()); + // jitter gets clipped to 0. + assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + } + + #[test] + fn jitter_greater_than_one() { + let config = BackoffConfig { + multiplier: 1.0, + jitter: 2.0, + base_delay: Duration::from_secs(10), + max_delay: Duration::from_secs(100), + }; + let backoff = ExponentialBackoff::new(config.clone()); + // jitter gets clipped to 1. + // 0 <= duration <= 20. 
+ let duration = backoff.backoff_duration(); + assert_eq!(duration.lt(&Duration::from_secs(20)), true); + assert_eq!(duration.gt(&Duration::from_secs(0)), true); + + let duration = backoff.backoff_duration(); + assert_eq!(duration.lt(&Duration::from_secs(20)), true); + assert_eq!(duration.gt(&Duration::from_secs(0)), true); + } + + #[test] + fn backoff_reset_no_jitter() { + let config = BackoffConfig { + multiplier: 2.0, + jitter: 0.0, + base_delay: Duration::from_secs(1), + max_delay: Duration::from_secs(15), + }; + let backoff = ExponentialBackoff::new(config.clone()); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(1)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(2)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(4)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(8)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(15)); + // Duration is capped to max_delay. + assert_eq!(backoff.backoff_duration(), Duration::from_secs(15)); + + // reset and repeat. + backoff.reset(); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(1)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(2)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(4)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(8)); + assert_eq!(backoff.backoff_duration(), Duration::from_secs(15)); + // Duration is capped to max_delay. + assert_eq!(backoff.backoff_duration(), Duration::from_secs(15)); + } + + #[test] + fn backoff_with_jitter() { + let config = BackoffConfig { + multiplier: 2.0, + jitter: 0.2, + base_delay: Duration::from_secs(1), + max_delay: Duration::from_secs(15), + }; + let backoff = ExponentialBackoff::new(config.clone()); + // 0.8 <= duration <= 1.2. + let duration = backoff.backoff_duration(); + assert_eq!(duration.gt(&Duration::from_secs_f64(0.8 - EPSILON)), true); + assert_eq!(duration.lt(&Duration::from_secs_f64(1.2 + EPSILON)), true); + // 1.6 <= duration <= 2.4. 
+ let duration = backoff.backoff_duration(); + assert_eq!(duration.gt(&Duration::from_secs_f64(1.6 - EPSILON)), true); + assert_eq!(duration.lt(&Duration::from_secs_f64(2.4 + EPSILON)), true); + // 3.2 <= duration <= 4.8. + let duration = backoff.backoff_duration(); + assert_eq!(duration.gt(&Duration::from_secs_f64(3.2 - EPSILON)), true); + assert_eq!(duration.lt(&Duration::from_secs_f64(4.8 + EPSILON)), true); + } +} diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs new file mode 100644 index 000000000..14f6b0393 --- /dev/null +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -0,0 +1,380 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +//! This module implements a DNS resolver to be installed as the default resolver +//! in grpc. + +use std::{ + net::{IpAddr, SocketAddr}, + sync::{Arc, Mutex}, + time::{Duration, SystemTime}, +}; + +use once_cell::sync::Lazy; +use tokio::sync::mpsc::UnboundedSender; +use url::Host; + +use crate::{ + client::name_resolution::{Address, NopResolver, ResolverUpdate, TCP_IP_NETWORK_TYPE}, + rt, +}; + +use super::{ + backoff::{BackoffConfig, ExponentialBackoff, DEFAULT_EXPONENTIAL_CONFIG}, + Endpoint, Resolver, ResolverBuilder, GLOBAL_RESOLVER_REGISTRY, +}; + +#[cfg(test)] +mod test; + +const DEFAULT_PORT: u16 = 443; +const DEFAULT_DNS_PORT: u16 = 53; + +/// This specifies the maximum duration for a DNS resolution request. 
+/// If the timeout expires before a response is received, the request will be +/// canceled. +/// +/// It is recommended to set this value at application startup. Avoid modifying +/// this variable after initialization. +static RESOLVING_TIMEOUT: Lazy> = Lazy::new(|| Mutex::new(Duration::from_secs(30))); + +/// This is the minimum interval at which re-resolutions are allowed. This helps +/// to prevent excessive re-resolution. +static MIN_RESOLUTION_INTERVAL: Lazy> = + Lazy::new(|| Mutex::new(Duration::from_secs(30))); + +pub fn get_resolving_timeout() -> Duration { + RESOLVING_TIMEOUT.lock().unwrap().clone() +} + +pub fn set_resolving_timeout(duration: Duration) { + *RESOLVING_TIMEOUT.lock().unwrap() = duration; +} + +pub fn get_min_resolution_interval() -> Duration { + MIN_RESOLUTION_INTERVAL.lock().unwrap().clone() +} + +pub fn set_min_resolution_interval(duration: Duration) { + *MIN_RESOLUTION_INTERVAL.lock().unwrap() = duration; +} + +pub fn reg() { + GLOBAL_RESOLVER_REGISTRY.add_builder(Box::new(Builder {})); +} + +struct Builder {} + +struct DnsOptions { + min_resolution_interval: Duration, + resolving_timeout: Duration, + backoff_config: BackoffConfig, +} + +impl DnsResolver { + fn new( + target: &super::Target, + options: super::ResolverOptions, + dns_opts: DnsOptions, + ) -> Box { + let parsed = match parse_endpoint_and_authority(target) { + Ok(res) => res, + Err(err) => return nop_resolver_for_err(err.to_string(), options), + }; + let endpoint = parsed.endpoint; + let host = match endpoint.host { + Host::Domain(d) => d, + Host::Ipv4(ipv4) => { + return nop_resolver_for_ip(IpAddr::V4(ipv4), endpoint.port, options) + } + Host::Ipv6(ipv6) => { + return nop_resolver_for_ip(IpAddr::V6(ipv6), endpoint.port, options) + } + }; + let authority = parsed.authority; + let dns = match options.runtime.get_dns_resolver(rt::ResolverOptions { + server_addr: authority, + }) { + Ok(dns) => dns, + Err(err) => return nop_resolver_for_err(err.to_string(), options), + }; + 
let state = Arc::new(Mutex::new(InternalState { + addrs: Ok(Vec::new()), + })); + let state_copy = state.clone(); + let (resolve_now_tx, mut resolve_now_rx) = tokio::sync::mpsc::unbounded_channel::<()>(); + let (update_error_tx, update_error_rx) = + tokio::sync::mpsc::unbounded_channel::>(); + + let handle = options.runtime.clone().spawn(Box::pin(async move { + let backoff = ExponentialBackoff::new(dns_opts.backoff_config.clone()); + let state = state_copy; + let work_scheduler = options.work_scheduler; + let mut update_error_rx = update_error_rx; + loop { + let mut lookup_fut = dns.lookup_host_name(&host); + let mut timeout_fut = options.runtime.sleep(dns_opts.resolving_timeout); + let addrs = tokio::select! { + result = &mut lookup_fut => { + match result { + Ok(ips) => { + let addrs = ips + .into_iter() + .map(|ip| SocketAddr::new(ip, endpoint.port)) + .collect(); + Ok(addrs) + } + Err(err) => Err(err), + } + } + _ = &mut timeout_fut => { + Err("Timed out waiting for DNS resolution".to_string()) + } + }; + { + let mut internal_state = match state.lock() { + Ok(state) => state, + Err(_) => return, + }; + internal_state.addrs = addrs; + } + work_scheduler.schedule_work(); + let update_result = match update_error_rx.recv().await { + Some(res) => res, + None => return, + }; + let next_resoltion_time: SystemTime; + if update_result.is_err() { + next_resoltion_time = SystemTime::now() + .checked_add(backoff.backoff_duration()) + .unwrap(); + } else { + // Success resolving, wait for the next resolve_now. However, + // also wait MIN_RESOLUTION_INTERVAL at the very least to prevent + // constantly re-resolving. + backoff.reset(); + next_resoltion_time = SystemTime::now() + .checked_add(dns_opts.min_resolution_interval) + .unwrap(); + _ = resolve_now_rx.recv().await; + } + // Wait till next resolution time. + let duration = match next_resoltion_time.duration_since(SystemTime::now()) { + Ok(d) => d, + Err(_) => continue, // Time has already passed. 
+ }; + options.runtime.sleep(duration).await; + } + })); + + Box::new(DnsResolver { + state, + task_handle: handle, + resolve_now_requester: resolve_now_tx, + update_error_sender: update_error_tx, + }) + } +} + +impl ResolverBuilder for Builder { + fn build( + &self, + target: &super::Target, + options: super::ResolverOptions, + ) -> Box { + let dns_opts = DnsOptions { + min_resolution_interval: get_min_resolution_interval(), + resolving_timeout: get_resolving_timeout(), + backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + }; + DnsResolver::new(target, options, dns_opts) + } + + fn scheme(&self) -> &'static str { + "dns" + } + + fn is_valid_uri(&self, target: &super::Target) -> bool { + if let Err(err) = parse_endpoint_and_authority(target) { + eprintln!("{}", err); + false + } else { + true + } + } +} + +struct DnsResolver { + state: Arc>, + task_handle: Box, + resolve_now_requester: UnboundedSender<()>, + update_error_sender: UnboundedSender>, +} + +struct InternalState { + addrs: Result, String>, +} + +impl Resolver for DnsResolver { + fn resolve_now(&mut self) { + _ = self.resolve_now_requester.send(()); + } + + fn work(&mut self, channel_controller: &mut dyn super::ChannelController) { + let state = match self.state.lock() { + Err(_) => { + eprintln!("DNS resolver mutex poisoned, can't update channel"); + return; + } + Ok(s) => s, + }; + let endpoint_result = match &state.addrs { + Ok(addrs) => { + let endpoints: Vec<_> = addrs + .iter() + .map(|a| Endpoint { + addresses: vec![Address { + network_type: TCP_IP_NETWORK_TYPE, + address: a.to_string(), + ..Default::default() + }], + ..Default::default() + }) + .collect(); + Ok(endpoints) + } + Err(err) => Err(err.to_string()), + }; + let update = ResolverUpdate { + endpoints: endpoint_result, + ..Default::default() + }; + let status = channel_controller.update(update); + _ = self.update_error_sender.send(status); + } +} + +impl Drop for DnsResolver { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + 
+#[derive(Eq, PartialEq, Debug)] +struct HostPort { + host: Host, + port: u16, +} + +#[derive(Eq, PartialEq, Debug)] +struct ParseResult { + endpoint: HostPort, + authority: Option, +} + +fn parse_endpoint_and_authority(target: &super::Target) -> Result { + // Parse the endpoint. + let endpoint = target.path(); + let endpoint = endpoint.strip_prefix("/").unwrap_or(endpoint); + let parse_result = parse_host_port(endpoint, DEFAULT_PORT) + .map_err(|err| format!("Failed to parse target {}: {}", target, err))?; + let endpoint = parse_result.ok_or("Received empty endpoint host.".to_string())?; + + // Parse the authority. + let authority = target.authority_host_port(); + if authority.is_empty() { + return Ok(ParseResult { + endpoint, + authority: None, + }); + } + let parse_result = parse_host_port(&authority, DEFAULT_DNS_PORT) + .map_err(|err| format!("Failed to parse DNS authority {}: {}", target, err))?; + let Some(authority) = parse_result else { + return Ok(ParseResult { + endpoint, + authority: None, + }); + }; + let authority = match authority.host { + Host::Ipv4(ipv4) => SocketAddr::new(IpAddr::V4(ipv4), authority.port), + Host::Ipv6(ipv6) => SocketAddr::new(IpAddr::V6(ipv6), authority.port), + _ => { + return Err(format!("Received non-IP DNS authority {}", authority.host)); + } + }; + Ok(ParseResult { + endpoint, + authority: Some(authority), + }) +} + +/// Takes the user input string of the format "host:port" and default port, +/// returns the parsed host and port. If the string doesn't specify a port, the +/// default_port is returned. If the parsed URL has no host, +/// `Ok(None)` is returned. +fn parse_host_port(host_and_port: &str, default_port: u16) -> Result, String> { + // We need to use the https scheme otherwise url::Url::parse doesn't convert + // IP addresses to Host::Ipv4 or Host::Ipv6 if they could represent valid + // domains. 
+ let url = format!("https://{}", host_and_port); + let url = url.parse::().map_err(|err| err.to_string())?; + let port = url.port().unwrap_or(default_port); + let host = match url.host() { + Some(host) => host, + None => return Ok(None), + }; + // Convert the domain to an owned string. + let host = match host { + Host::Domain(s) => Host::Domain(s.to_owned()), + Host::Ipv4(ip) => Host::Ipv4(ip), + Host::Ipv6(ip) => Host::Ipv6(ip), + }; + Ok(Some(HostPort { host, port })) +} + +fn nop_resolver_for_ip( + ip: IpAddr, + port: u16, + options: super::ResolverOptions, +) -> Box { + options.work_scheduler.schedule_work(); + Box::new(NopResolver { + update: ResolverUpdate { + endpoints: Ok(vec![Endpoint { + addresses: vec![Address { + network_type: TCP_IP_NETWORK_TYPE, + address: SocketAddr::new(ip, port).to_string(), + ..Default::default() + }], + ..Default::default() + }]), + ..Default::default() + }, + }) +} + +fn nop_resolver_for_err(err: String, options: super::ResolverOptions) -> Box { + options.work_scheduler.schedule_work(); + Box::new(NopResolver { + update: ResolverUpdate { + endpoints: Err(err), + ..Default::default() + }, + }) +} diff --git a/grpc/src/client/name_resolution/dns/test.rs b/grpc/src/client/name_resolution/dns/test.rs new file mode 100644 index 000000000..6ef1077ea --- /dev/null +++ b/grpc/src/client/name_resolution/dns/test.rs @@ -0,0 +1,505 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +use std::{future::Future, pin::Pin, sync::Arc, time::Duration}; + +use tokio::sync::mpsc::{self, UnboundedSender}; + +use crate::{ + client::name_resolution::{ + self, + backoff::{BackoffConfig, DEFAULT_EXPONENTIAL_CONFIG}, + dns::{parse_endpoint_and_authority, HostPort}, + ResolverOptions, ResolverUpdate, Target, GLOBAL_RESOLVER_REGISTRY, + }, + rt::{self, tokio::TokioRuntime}, +}; + +use super::ParseResult; + +const DEFAULT_TEST_SHORT_TIMEOUT: Duration = Duration::from_millis(10); + +#[test] +pub fn target_parsing() { + struct TestCase { + input: &'static str, + want_result: Result, + } + let test_cases = vec![ + TestCase { + input: "dns:///grpc.io", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 443, + }, + authority: None, + }), + }, + TestCase { + input: "dns:///grpc.io:1234", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 1234, + }, + authority: None, + }), + }, + TestCase { + input: "dns://8.8.8.8/grpc.io:1234", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 1234, + }, + authority: Some("8.8.8.8:53".parse().unwrap()), + }), + }, + TestCase { + input: "dns://8.8.8.8:5678/grpc.io:1234/abc", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 1234, + }, + authority: Some("8.8.8.8:5678".parse().unwrap()), + }), + }, + TestCase { + input: "dns://[::1]:5678/grpc.io:1234/abc", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 1234, + }, + authority: Some("[::1]:5678".parse().unwrap()), + }), + }, + TestCase { + input: "dns://[fe80::1]:5678/127.0.0.1:1234/abc", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Ipv4("127.0.0.1".parse().unwrap()), + port: 1234, + }, + authority: 
Some("[fe80::1]:5678".parse().unwrap()), + }), + }, + TestCase { + input: "dns:///[fe80::1%80]:5678/abc", + want_result: Err("SocketAddr doesn't support IPv6 addresses with zones".to_string()), + }, + TestCase { + input: "dns:///:5678/abc", + want_result: Err("Empty host with port".to_string()), + }, + TestCase { + input: "dns:///grpc.io:abc/abc", + want_result: Err("Non numeric port".to_string()), + }, + TestCase { + input: "dns:///grpc.io:/", + want_result: Ok(ParseResult { + endpoint: HostPort { + host: url::Host::Domain("grpc.io".to_string()), + port: 443, + }, + authority: None, + }), + }, + TestCase { + input: "dns:///:", + want_result: Err("No host and port".to_string()), + }, + TestCase { + input: "dns:///[2001:db8:a0b:12f0::1", + want_result: Err("Invalid address".to_string()), + }, + ]; + + for tc in test_cases { + let target: Target = tc.input.parse().unwrap(); + let got = parse_endpoint_and_authority(&target); + if got.is_err() != tc.want_result.is_err() { + panic!( + "Got error {:?}, want error: {:?}", + got.err(), + tc.want_result.err() + ); + } + if got.is_err() { + continue; + } + assert_eq!(got.unwrap(), tc.want_result.unwrap()); + } +} + +struct WorkScheduler { + work_tx: UnboundedSender<()>, +} + +impl name_resolution::WorkScheduler for WorkScheduler { + fn schedule_work(&self) { + self.work_tx.send(()).unwrap(); + } +} + +struct FakeChannelController { + update_result: Result<(), String>, + update_tx: UnboundedSender, +} + +impl name_resolution::ChannelController for FakeChannelController { + fn update(&mut self, update: name_resolution::ResolverUpdate) -> Result<(), String> { + println!("Received resolver update: {:?}", &update); + self.update_tx.send(update).unwrap(); + self.update_result.clone() + } + + fn parse_service_config( + &self, + _: &str, + ) -> Result { + Err("Unimplemented".to_string()) + } +} + +#[tokio::test] +pub async fn dns_basic() { + super::reg(); + let builder = GLOBAL_RESOLVER_REGISTRY.get("dns").unwrap(); + let target = 
&"dns:///localhost:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let mut resolver = builder.build(target, opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // A successful endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); +} + +#[tokio::test] +pub async fn invalid_target() { + super::reg(); + let builder = GLOBAL_RESOLVER_REGISTRY.get("dns").unwrap(); + let target = &"dns:///:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let mut resolver = builder.build(target, opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // An error endpoint update should be received. 
+ let update = update_rx.recv().await.unwrap(); + assert_eq!( + update + .endpoints + .err() + .unwrap() + .contains(&target.to_string()), + true + ); +} + +#[derive(Clone)] +struct FakeDns { + latency: Duration, + lookup_result: Result, String>, +} + +#[tonic::async_trait] +impl rt::DnsResolver for FakeDns { + async fn lookup_host_name(&self, _: &str) -> Result, String> { + tokio::time::sleep(self.latency).await; + self.lookup_result.clone() + } + + async fn lookup_txt(&self, _: &str) -> Result, String> { + Err("unimplemented".to_string()) + } +} + +struct FakeRuntime { + inner: TokioRuntime, + dns: FakeDns, +} + +impl rt::Runtime for FakeRuntime { + fn spawn( + &self, + task: Pin + Send + 'static>>, + ) -> Box { + self.inner.spawn(task) + } + + fn get_dns_resolver(&self, _: rt::ResolverOptions) -> Result, String> { + Ok(Box::new(self.dns.clone())) + } + + fn sleep(&self, duration: std::time::Duration) -> Pin> { + self.inner.sleep(duration) + } +} + +#[tokio::test] +pub async fn dns_lookup_error() { + super::reg(); + let builder = GLOBAL_RESOLVER_REGISTRY.get("dns").unwrap(); + let target = &"dns:///grpc.io:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let runtime = FakeRuntime { + inner: TokioRuntime {}, + dns: FakeDns { + latency: Duration::from_secs(0), + lookup_result: Err("test_error".to_string()), + }, + }; + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(runtime), + work_scheduler: work_scheduler.clone(), + }; + let mut resolver = builder.build(target, opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // An error endpoint update should be received. 
+ let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.err().unwrap().contains("test_error"), true); +} + +#[tokio::test] +pub async fn dns_lookup_timeout() { + let target = &"dns:///grpc.io:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let runtime = FakeRuntime { + inner: TokioRuntime {}, + dns: FakeDns { + latency: Duration::from_secs(20), + lookup_result: Ok(Vec::new()), + }, + }; + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(runtime), + work_scheduler: work_scheduler.clone(), + }; + let dns_opts = super::DnsOptions { + min_resolution_interval: super::get_min_resolution_interval(), + resolving_timeout: DEFAULT_TEST_SHORT_TIMEOUT, + backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + }; + let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + + // An error endpoint update should be received. 
+ let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.err().unwrap().contains("Timed out"), true); +} + +#[tokio::test] +pub async fn rate_limit() { + let target = &"dns:///localhost:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let dns_opts = super::DnsOptions { + min_resolution_interval: Duration::from_secs(20), + resolving_timeout: super::get_resolving_timeout(), + backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + }; + let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + + // Wait for schedule work to be called. + let event = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // A successful endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); + + // Call resolve_now repeatedly, new updates should not be produced. + for _ in 0..5 { + resolver.resolve_now(); + tokio::select! 
{ + _ = work_rx.recv() => { + panic!("Received unexpected work request from resolver: {:?}", event); + } + _ = tokio::time::sleep(DEFAULT_TEST_SHORT_TIMEOUT) => { + println!("No work requested from resolver."); + } + }; + } +} + +#[tokio::test] +pub async fn re_resolution_after_success() { + let target = &"dns:///localhost:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let dns_opts = super::DnsOptions { + min_resolution_interval: Duration::from_millis(1), + resolving_timeout: super::get_resolving_timeout(), + backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + }; + let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + + // Wait for schedule work to be called. + let _ = work_rx.recv().await.unwrap(); + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Ok(()), + }; + resolver.work(&mut channel_controller); + // A successful endpoint update should be received. + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); + + // Call resolve_now, a new update should be produced. 
+ resolver.resolve_now(); + let _ = work_rx.recv().await.unwrap(); + resolver.work(&mut channel_controller); + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); +} + +#[tokio::test] +pub async fn backoff_on_error() { + let target = &"dns:///localhost:1234".parse().unwrap(); + let (work_tx, mut work_rx) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(WorkScheduler { + work_tx: work_tx.clone(), + }); + let opts = ResolverOptions { + authority: "ignored".to_string(), + runtime: Arc::new(TokioRuntime {}), + work_scheduler: work_scheduler.clone(), + }; + let dns_opts = super::DnsOptions { + min_resolution_interval: Duration::from_millis(1), + resolving_timeout: super::get_resolving_timeout(), + // Speed up the backoffs to make the test run faster. + backoff_config: BackoffConfig { + base_delay: Duration::from_millis(1), + multiplier: 1.0, + jitter: 0.0, + max_delay: Duration::from_millis(1), + }, + }; + let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + + let (update_tx, mut update_rx) = mpsc::unbounded_channel(); + let mut channel_controller = FakeChannelController { + update_tx, + update_result: Err("test_error".to_string()), + }; + + // As the channel returned an error to the resolver, the resolver will + // backoff and re-attempt resolution. + for _ in 0..5 { + let _ = work_rx.recv().await.unwrap(); + resolver.work(&mut channel_controller); + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); + } + + // This time the channel accepts the resolver update. + channel_controller.update_result = Ok(()); + let _ = work_rx.recv().await.unwrap(); + resolver.work(&mut channel_controller); + let update = update_rx.recv().await.unwrap(); + assert_eq!(update.endpoints.unwrap().len() > 1, true); + + // Since the channel controller returns Ok(), the resolver will stop + // producing more updates. + tokio::select! 
{ + _ = work_rx.recv() => { + panic!("Received unexpected work request from resolver."); + } + _ = tokio::time::sleep(DEFAULT_TEST_SHORT_TIMEOUT) => { + println!("No event received from resolver."); + } + }; +} diff --git a/grpc/src/client/name_resolution/mod.rs b/grpc/src/client/name_resolution/mod.rs index 05c5cf795..1a24c371f 100644 --- a/grpc/src/client/name_resolution/mod.rs +++ b/grpc/src/client/name_resolution/mod.rs @@ -23,95 +23,214 @@ //! a service. use core::fmt; +use super::service_config::ServiceConfig; +use crate::{attributes::Attributes, rt}; use std::{ - error::Error, fmt::{Display, Formatter}, hash::Hash, + str::FromStr, sync::Arc, }; -use tokio::sync::Notify; -use tonic::async_trait; -use url::Url; +mod backoff; +mod dns; +mod registry; +pub use registry::GLOBAL_RESOLVER_REGISTRY; -use crate::attributes::Attributes; +/// Target represents a target for gRPC, as specified in: +/// https://github.com/grpc/grpc/blob/master/doc/naming.md. +/// It is parsed from the target string that gets passed during channel creation +/// by the user. gRPC passes it to the resolver and the balancer. +/// +/// If the target follows the naming spec, and the parsed scheme is registered +/// with gRPC, we will parse the target string according to the spec. If the +/// target does not contain a scheme or if the parsed scheme is not registered +/// (i.e. no corresponding resolver available to resolve the endpoint), we will +/// apply the default scheme, and will attempt to reparse it. +#[derive(Debug, Clone)] +pub struct Target { + url: url::Url, +} -use super::service_config::ServiceConfig; +impl FromStr for Target { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.parse::() { + Ok(url) => Ok(Target { url }), + Err(err) => Err(err.to_string()), + } + } +} + +impl Target { + pub fn scheme(&self) -> &str { + self.url.scheme() + } + + /// The host part of the authority. 
+ pub fn authority_host(&self) -> &str { + self.url.host_str().unwrap_or("") + } + + /// The port part of the authority. + pub fn authority_port(&self) -> Option { + self.url.port() + } + + /// Returns either host:port or host depending on the existence of the port + /// in the authority. + pub fn authority_host_port(&self) -> String { + let host = self.authority_host(); + let port = self.authority_port(); + if let Some(port) = port { + format!("{}:{}", host, port) + } else { + host.to_owned() + } + } + + /// Return the path for this target URL, as a percent-encoded ASCII string. + pub fn path(&self) -> &str { + self.url.path() + } +} + +impl Display for Target { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}//{}/{}", + self.scheme(), + self.authority_host_port(), + self.path() + ) + } +} /// A name resolver factory that produces Resolver instances used by the channel /// to resolve network addresses for the target URI. pub trait ResolverBuilder: Send + Sync { - /// Builds and returns a new name resolver instance. + /// Builds a name resolver instance. /// /// Note that build must not fail. Instead, an erroring Resolver may be /// returned that calls ChannelController.update() with an Err value. - fn build( - &self, - target: Url, - resolve_now: Arc, - options: ResolverOptions, - ) -> Box; + fn build(&self, target: &Target, options: ResolverOptions) -> Box; /// Reports the URI scheme handled by this name resolver. fn scheme(&self) -> &'static str; - /// Returns the default authority for a channel using this name resolver and - /// target. This is typically the same as the service's name. By default, - /// the default_authority method automatically returns the path portion of - /// the target URI, with the leading prefix removed. - fn default_authority(&self, target: &Url) -> String { + /// Returns the default authority for a channel using this name resolver + /// and target. This is typically the same as the service's name. 
By + /// default, the default_authority method automatically returns the path + /// portion of the target URI, with the leading prefix removed. + fn default_authority(&self, target: &Target) -> String { let path = target.path(); path.strip_prefix("/").unwrap_or(path).to_string() } + + /// Returns a bool indicating whether the input uri is valid to create a + /// resolver. + fn is_valid_uri(&self, uri: &Target) -> bool; } /// A collection of data configured on the channel that is constructing this /// name resolver. -#[derive(Debug, Default)] #[non_exhaustive] pub struct ResolverOptions { /// The authority that will be used for the channel by default. This /// contains either the result of the default_authority method of this /// ResolverBuilder, or another string if the channel was configured to /// override the default. - authority: String, + pub authority: String, + + /// The runtime which provides utilities to do async work. + pub runtime: Arc, + + /// A hook into the channel's work scheduler that allows the Resolver to + /// request the ability to perform operations on the ChannelController. + pub work_scheduler: Arc, } -#[async_trait] -/// A collection of operations a Resolver may perform on the channel which -/// constructed it. -pub trait ChannelController: Send + Sync { - /// Parses the provided JSON service config. - fn parse_config(&self, config: &str) -> Result>; // TODO +/// Used to asynchronously request a call into the Resolver's work method. +pub trait WorkScheduler: Send + Sync { + // Schedules a call into the Resolver's work method. If there is already a + // pending work call that has not yet started, this may not schedule another + // call. + fn schedule_work(&self); +} + +/// Resolver watches for the updates on the specified target. +/// Updates include address updates and service config updates. +pub trait Resolver: Send { + /// Asks the resolver to obtain an updated resolver result, if + /// applicable. 
+ /// + /// This is useful for pull-based implementations to decide when to + /// re-resolve. However, the implementation is not required to + /// re-resolve immediately upon receiving this call; it may instead + /// elect to delay based on some configured minimum time between + /// queries, to avoid hammering the name service with queries. + /// + /// For push-based implementations, this may be a no-op. + fn resolve_now(&mut self); + /// Called serially by the channel to do work after the work scheduler's + /// schedule_work method is called. + fn work(&mut self, channel_controller: &mut dyn ChannelController); +} + +/// The `ChannelController` trait provides the resolver with functionality +/// to interact with the channel. +pub trait ChannelController: Send + Sync { /// Notifies the channel about the current state of the name resolver. If /// an error value is returned, the name resolver should attempt to /// re-resolve, if possible. The resolver is responsible for applying an /// appropriate backoff mechanism to avoid overloading the system or the /// remote resolver. - async fn update(&self, update: ResolverUpdate) -> Result<(), Box>; -} + fn update(&mut self, update: ResolverUpdate) -> Result<(), String>; -/// A name resolver update expresses the current state of the resolver. -pub enum ResolverUpdate { - /// Indicates the name resolver encountered an error. - Err(Box), - /// Indicates the name resolver produced a valid result. - Data(ResolverData), + /// Parses the provided JSON service config and returns an instance of a + /// ParsedServiceConfig. + fn parse_service_config(&self, config: &str) -> Result; } -/// Data provided by the name resolver to the channel. -#[derive(Debug, Default)] +#[derive(Clone, Debug)] #[non_exhaustive] -pub struct ResolverData { +/// ResolverUpdate contains the current Resolver state relevant to the +/// channel. 
+pub struct ResolverUpdate { + /// Attributes contains arbitrary data about the resolver intended for + /// consumption by the load balancing policy. + pub attributes: Arc, + /// A list of endpoints which each identify a logical host serving the /// service indicated by the target URI. - pub endpoints: Vec, + pub endpoints: Result, String>, + /// The service config which the client should use for communicating with - /// the service. - pub service_config: Option, - // Optional data which may be used by the LB Policy or channel. - pub attributes: Attributes, + /// the service. If it is None, it indicates no service config is present or + /// the resolver does not provide service configs. + pub service_config: Result, String>, + + /// An optional human-readable note describing context about the + /// resolution, to be passed along to the LB policy for inclusion in + /// RPC failure status messages in cases where neither endpoints nor + /// service_config has a non-OK status. For example, a resolver that + /// returns an empty endpoint list but a valid service config may set + /// this to something like "no DNS entries found for <name>". + pub resolution_note: Option, +} + +impl Default for ResolverUpdate { + fn default() -> Self { + ResolverUpdate { + service_config: Ok(None), + attributes: Arc::default(), + endpoints: Ok(Vec::default()), + resolution_note: None, + } + } } /// An Endpoint is an address or a collection of addresses which reference one @@ -120,9 +239,28 @@ pub struct ResolverData { #[derive(Debug, Default, Clone)] #[non_exhaustive] pub struct Endpoint { - /// The list of addresses used to connect to the server. + /// Addresses contains a list of addresses used to access this endpoint. pub addresses: Vec
, - /// Optional data which may be used by the LB policy or channel. + + /// Attributes contains arbitrary data about this endpoint intended for + /// consumption by the LB policy. + pub attributes: Attributes, +} + +/// An Address is an identifier that indicates how to connect to a server. +#[non_exhaustive] +#[derive(Debug, Clone, Default)] +pub struct Address { + /// The network type is used to identify what kind of transport to create + /// when connecting to this address. Typically TCP_IP_ADDRESS_TYPE. + pub network_type: &'static str, + + /// The address itself is passed to the transport in order to create a + /// connection to it. + pub address: String, + + /// Attributes contains arbitrary data about this address intended for + /// consumption by the subchannel. pub attributes: Attributes, } @@ -140,20 +278,6 @@ impl Hash for Endpoint { } } -/// An Address is an identifier that indicates how to connect to a server. -#[derive(Debug, Default, Clone)] -#[non_exhaustive] -pub struct Address { - /// The address type is used to identify what kind of transport to create - /// when connecting to this address. Typically TCP_IP_ADDRESS_TYPE. - pub address_type: String, // TODO: &'static str? - /// The address itself is passed to the transport in order to create a - /// connection to it. - pub address: String, - // Optional data which the transport may use for the connection. - pub attributes: Attributes, -} - impl Eq for Address {} impl PartialEq for Address { @@ -170,20 +294,86 @@ impl Hash for Address { impl Display for Address { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}:{}", self.address_type, self.address) + write!(f, "{}:{}", self.network_type, self.address) } } /// Indicates the address is an IPv4 or IPv6 address that should be connected to /// via TCP/IP. -pub static TCP_IP_ADDRESS_TYPE: &str = "tcp"; - -/// A name resolver instance. -#[async_trait] -pub trait Resolver: Send + Sync { - /// The entry point of the resolver. 
Will only be called once by the - /// channel. Should not return unless the resolver never will need to - /// update its state. The future will be dropped when the channel shuts - /// down or enters idle mode. - async fn run(&mut self, channel_controller: Box); +pub static TCP_IP_NETWORK_TYPE: &str = "tcp"; + +// A resolver that returns the same result every time its work method is called. +// It can be used to return an error to the channel when a resolver fails to +// build. +struct NopResolver { + pub update: ResolverUpdate, +} + +impl Resolver for NopResolver { + fn resolve_now(&mut self) {} + + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + let _ = channel_controller.update(self.update.clone()); + } +} + +#[cfg(test)] +mod test { + use super::Target; + + #[test] + pub fn parse_target() { + #[derive(Default)] + struct TestCase { + input: &'static str, + want_scheme: &'static str, + want_host: &'static str, + want_port: Option, + want_host_port: &'static str, + want_path: &'static str, + } + let test_cases = vec![ + TestCase { + input: "dns:///grpc.io", + want_scheme: "dns", + want_host_port: "", + want_host: "", + want_port: None, + want_path: "/grpc.io", + }, + TestCase { + input: "dns://8.8.8.8:53/grpc.io/docs", + want_scheme: "dns", + want_host_port: "8.8.8.8:53", + want_host: "8.8.8.8", + want_port: Some(53), + want_path: "/grpc.io/docs", + }, + TestCase { + input: "unix:path/to/file", + want_scheme: "unix", + want_host_port: "", + want_host: "", + want_port: None, + want_path: "path/to/file", + }, + TestCase { + input: "unix:///run/containerd/containerd.sock", + want_scheme: "unix", + want_host_port: "", + want_host: "", + want_port: None, + want_path: "/run/containerd/containerd.sock", + }, + ]; + + for tc in test_cases { + let target: Target = tc.input.parse().unwrap(); + assert_eq!(target.scheme(), tc.want_scheme); + assert_eq!(target.authority_host(), tc.want_host); + assert_eq!(target.authority_port(), tc.want_port); + 
assert_eq!(target.authority_host_port(), tc.want_host_port); + assert_eq!(target.path(), tc.want_path); + } + } } diff --git a/grpc/src/client/name_resolution/registry.rs b/grpc/src/client/name_resolution/registry.rs new file mode 100644 index 000000000..26ecc93b1 --- /dev/null +++ b/grpc/src/client/name_resolution/registry.rs @@ -0,0 +1,70 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use super::ResolverBuilder; + +/// A registry to store and retrieve name resolvers. Resolvers are indexed by +/// the URI scheme they are intended to handle. +#[derive(Default)] +pub struct ResolverRegistry { + m: Arc>>>, +} + +impl ResolverRegistry { + /// Construct an empty name resolver registry. + fn new() -> Self { + Self { m: Arc::default() } + } + + /// Add a name resolver into the registry. builder.scheme() will + /// be used as the scheme registered with this builder. If multiple + /// resolvers are registered with the same name, the one registered last + /// will take effect. Panics if the given scheme contains uppercase + /// characters. 
+ pub fn add_builder(&self, builder: Box) { + let scheme = builder.scheme(); + if scheme.chars().any(|c| c.is_ascii_uppercase()) { + panic!("Scheme must not contain uppercase characters: {}", scheme); + } + self.m + .lock() + .unwrap() + .insert(scheme.to_string(), Arc::from(builder)); + } + + /// Returns the resolver builder registered for the given scheme, if any. + /// + /// The provided scheme is case-insensitive; any uppercase characters + /// will be converted to lowercase before lookup. + pub fn get(&self, scheme: &str) -> Option> { + self.m + .lock() + .unwrap() + .get(&scheme.to_lowercase()) + .map(|b| b.clone()) + } +} + +/// Global registry for resolver builders. +pub static GLOBAL_RESOLVER_REGISTRY: std::sync::LazyLock = + std::sync::LazyLock::new(ResolverRegistry::new); diff --git a/grpc/src/client/service_config.rs b/grpc/src/client/service_config.rs index 3639c03e5..1d5ad153a 100644 --- a/grpc/src/client/service_config.rs +++ b/grpc/src/client/service_config.rs @@ -18,5 +18,5 @@ /// An in-memory representation of a service config, usually provided to gRPC as /// a JSON object. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub(crate) struct ServiceConfig; diff --git a/grpc/src/lib.rs b/grpc/src/lib.rs index 244a7a904..a63f4d323 100644 --- a/grpc/src/lib.rs +++ b/grpc/src/lib.rs @@ -27,6 +27,7 @@ #![allow(dead_code)] pub mod client; +mod rt; pub mod service; pub(crate) mod attributes; diff --git a/grpc/src/rt/mod.rs b/grpc/src/rt/mod.rs new file mode 100644 index 000000000..e7b0405a1 --- /dev/null +++ b/grpc/src/rt/mod.rs @@ -0,0 +1,69 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use std::{future::Future, pin::Pin}; + +pub mod tokio; + +/// An abstraction over an asynchronous runtime. +/// +/// The `Runtime` trait defines the core functionality required for +/// executing asynchronous tasks, creating DNS resolvers, and performing +/// time-based operations such as sleeping. It provides a uniform interface +/// that can be implemented for various async runtimes, enabling pluggable +/// and testable infrastructure. +pub trait Runtime: Send + Sync { + /// Spawns the given asynchronous task to run in the background. + fn spawn( + &self, + task: Pin + Send + 'static>>, + ) -> Box; + + /// Creates and returns an instance of a DNSResolver, optionally + /// configured by the ResolverOptions struct. This method may return an + /// error if it fails to create the DNSResolver. + fn get_dns_resolver(&self, opts: ResolverOptions) -> Result, String>; + + /// Returns a future that completes after the specified duration. + fn sleep(&self, duration: std::time::Duration) -> Pin>; +} + +/// A future that resolves after a specified duration. +pub trait Sleep: Send + Sync + Future {} + +pub trait TaskHandle: Send + Sync { + /// Abort the associated task. + fn abort(&self); +} + +/// A trait for asynchronous DNS resolution. +#[tonic::async_trait] +pub trait DnsResolver: Send + Sync { + /// Resolve an address + async fn lookup_host_name(&self, name: &str) -> Result, String>; + /// Perform a TXT record lookup. If a txt record contains multiple strings, + /// they are concatenated. 
+ async fn lookup_txt(&self, name: &str) -> Result, String>; +} + +#[derive(Default)] +pub struct ResolverOptions { + /// The address of the DNS server in "IP:port" format. If None, the + /// system's default DNS server will be used. + pub server_addr: Option, +} diff --git a/grpc/src/rt/tokio/hickory_resolver.rs b/grpc/src/rt/tokio/hickory_resolver.rs new file mode 100644 index 000000000..13b20b010 --- /dev/null +++ b/grpc/src/rt/tokio/hickory_resolver.rs @@ -0,0 +1,234 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}; + +/// A DNS resolver that uses hickory with the tokio runtime. This supports txt +/// lookups in addition to A and AAAA record lookups. It also supports using +/// custom DNS servers. +pub struct DnsResolver { + resolver: hickory_resolver::TokioResolver, +} + +#[tonic::async_trait] +impl super::DnsResolver for DnsResolver { + async fn lookup_host_name(&self, name: &str) -> Result, String> { + let response = self + .resolver + .lookup_ip(name) + .await + .map_err(|err| err.to_string())?; + Ok(response.iter().collect()) + } + + async fn lookup_txt(&self, name: &str) -> Result, String> { + let response: Vec<_> = self + .resolver + .txt_lookup(name) + .await + .map_err(|err| err.to_string())? 
+ .iter() + .map(|txt_record| { + txt_record + .iter() + .map(|bytes| String::from_utf8_lossy(bytes).into_owned()) + .collect::>() + .join("") + }) + .collect(); + Ok(response) + } +} + +impl DnsResolver { + pub fn new(opts: super::ResolverOptions) -> Result { + let builder = if let Some(server_addr) = opts.server_addr { + let provider = hickory_resolver::name_server::TokioConnectionProvider::default(); + let name_servers = NameServerConfigGroup::from_ips_clear( + &[server_addr.ip()], + server_addr.port(), + true, + ); + let config = ResolverConfig::from_parts(None, vec![], name_servers); + hickory_resolver::TokioResolver::builder_with_config(config, provider) + } else { + hickory_resolver::TokioResolver::builder_tokio().map_err(|err| err.to_string())? + }; + let mut resolver_opts = ResolverOpts::default(); + resolver_opts.ip_strategy = hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6; + Ok(DnsResolver { + resolver: builder.with_options(resolver_opts).build(), + }) + } +} + +#[cfg(test)] +mod tests { + use std::{ + net::{Ipv4Addr, SocketAddr}, + sync::Arc, + }; + + use hickory_resolver::Name; + use hickory_server::{ + authority::{Catalog, ZoneType}, + proto::rr::{ + rdata::{A, TXT}, + LowerName, RData, Record, + }, + store::in_memory::InMemoryAuthority, + ServerFuture, + }; + use tokio::{net::UdpSocket, sync::oneshot, task::JoinHandle}; + + use crate::rt::{tokio::TokioDefaultDnsResolver, DnsResolver, ResolverOptions}; + + #[tokio::test] + async fn compare_hickory_and_default() { + let hickory_dns = super::DnsResolver::new(ResolverOptions::default()).unwrap(); + let mut ips_hickory = hickory_dns.lookup_host_name("localhost").await.unwrap(); + + let default_resolver = TokioDefaultDnsResolver::new(ResolverOptions::default()).unwrap(); + + let mut system_resolver_ips = default_resolver + .lookup_host_name("localhost") + .await + .unwrap(); + + // Hickory requests A and AAAA records in parallel, so the order of IPv4 + // and IPv6 addresses isn't deterministic. 
+ ips_hickory.sort(); + system_resolver_ips.sort(); + assert_eq!( + ips_hickory, system_resolver_ips, + "both resolvers should produce same IPs for localhost" + ) + } + + #[tokio::test] + async fn resolve_txt() { + let records = vec![ + Record::from_rdata( + Name::from_ascii("test.local.").unwrap(), + 300, + RData::TXT(TXT::new(vec![ + "one".to_string(), + "two".to_string(), + "three".to_string(), + ])), + ), + Record::from_rdata( + Name::from_ascii("test.local.").unwrap(), + 300, + RData::TXT(TXT::new(vec![ + "abc".to_string(), + "def".to_string(), + "ghi".to_string(), + ])), + ), + ]; + + let dns = start_in_memory_dns_server("test.local.", records).await; + let opts = ResolverOptions { + server_addr: Some(dns.addr), + }; + let hickory_dns = super::DnsResolver::new(opts).unwrap(); + + let txt = hickory_dns.lookup_txt("test.local").await.unwrap(); + assert_eq!( + txt, + vec!["onetwothree".to_string(), "abcdefghi".to_string(),] + ); + dns.shutdown().await; + } + + #[tokio::test] + async fn custom_authority() { + let record = Record::from_rdata( + Name::from_ascii("test.local.").unwrap(), + 300, + RData::A(A(Ipv4Addr::new(1, 2, 3, 4))), + ); + let dns = start_in_memory_dns_server("test.local.", vec![record]).await; + let opts = ResolverOptions { + server_addr: Some(dns.addr), + }; + let hickory_dns = super::DnsResolver::new(opts).unwrap(); + let ips = hickory_dns.lookup_host_name("test.local").await.unwrap(); + assert_eq!(ips, vec![Ipv4Addr::new(1, 2, 3, 4)]); + dns.shutdown().await + } + + struct FakeDns { + tx: Option>, + join_handle: Option>, + addr: SocketAddr, + } + + impl FakeDns { + async fn shutdown(mut self) { + let tx = self.tx.take().unwrap(); + tx.send(()).unwrap(); + let handle = self.join_handle.take().unwrap(); + handle.await.unwrap(); + } + } + + /// Starts an in-memory DNS server and adds the given records. Returns + /// a DNS server which should be shut down after the test. It uses a random + /// port to bind since tests can run in parallel. 
The assigned port can be + /// read from the returned struct. + async fn start_in_memory_dns_server(host: &str, records: Vec) -> FakeDns { + // Create a simple A record for `test.local.` + let authority = + InMemoryAuthority::empty(Name::from_ascii(host).unwrap(), ZoneType::Primary, false); + + for record in records { + authority.upsert(record, 0).await; + } + + let mut catalog = Catalog::new(); + catalog.upsert( + LowerName::new(&Name::from_ascii(host).unwrap()), + vec![Arc::new(authority)], + ); + + let mut server = ServerFuture::new(catalog); + + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = socket.local_addr().unwrap(); + server.register_socket(socket); + + println!("DNS server running on {}", addr); + + let (tx, rx) = oneshot::channel::<()>(); + let server_task = tokio::spawn(async move { + tokio::select! { + _ = server.block_until_done() => {}, + _ = rx => { + server.shutdown_gracefully().await.unwrap(); + } + } + }); + FakeDns { + tx: Some(tx), + join_handle: Some(server_task), + addr, + } + } +} diff --git a/grpc/src/rt/tokio/mod.rs b/grpc/src/rt/tokio/mod.rs new file mode 100644 index 000000000..fbebd6f78 --- /dev/null +++ b/grpc/src/rt/tokio/mod.rs @@ -0,0 +1,127 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +use std::{future::Future, net::SocketAddr, pin::Pin}; + +use super::{DnsResolver, ResolverOptions, Runtime, Sleep, TaskHandle}; + +#[cfg(feature = "hickory_dns")] +mod hickory_resolver; + +/// A DNS resolver that uses tokio::net::lookup_host for resolution. It only +/// supports host lookups. +pub struct TokioDefaultDnsResolver {} + +#[tonic::async_trait] +impl DnsResolver for TokioDefaultDnsResolver { + async fn lookup_host_name(&self, name: &str) -> Result, String> { + let name_with_port = match name.parse::() { + Ok(ip) => SocketAddr::new(ip, 0).to_string(), + Err(_) => format!("{}:0", name), + }; + let ips = tokio::net::lookup_host(name_with_port) + .await + .map_err(|err| err.to_string())? + .map(|socket_addr| socket_addr.ip()) + .collect(); + Ok(ips) + } + + async fn lookup_txt(&self, _name: &str) -> Result, String> { + Err("TXT record lookup unavailable. Enable the optional 'hickory_dns' feature to enable service config lookups.".to_string()) + } +} + +pub struct TokioRuntime {} + +impl TaskHandle for tokio::task::JoinHandle<()> { + fn abort(&self) { + self.abort() + } +} + +impl Sleep for tokio::time::Sleep {} + +impl Runtime for TokioRuntime { + fn spawn( + &self, + task: Pin + Send + 'static>>, + ) -> Box { + Box::new(tokio::spawn(task)) + } + + fn get_dns_resolver(&self, opts: ResolverOptions) -> Result, String> { + #[cfg(feature = "hickory_dns")] + { + Ok(Box::new(hickory_resolver::DnsResolver::new(opts)?)) + } + #[cfg(not(feature = "hickory_dns"))] + { + Ok(Box::new(TokioDefaultDnsResolver::new(opts)?)) + } + } + + fn sleep(&self, duration: std::time::Duration) -> Pin> { + Box::pin(tokio::time::sleep(duration)) + } +} + +impl TokioDefaultDnsResolver { + pub fn new(opts: ResolverOptions) -> Result { + if opts.server_addr.is_some() { + return Err("Custom DNS server are not supported, enable optional feature 'hickory_dns' to enable support.".to_string()); + } + Ok(TokioDefaultDnsResolver {}) + } +} + +#[cfg(test)] +mod tests { + use 
super::{DnsResolver, ResolverOptions, Runtime, TokioDefaultDnsResolver, TokioRuntime}; + + #[tokio::test] + async fn lookup_hostname() { + let runtime = TokioRuntime {}; + + let dns = runtime + .get_dns_resolver(ResolverOptions::default()) + .unwrap(); + let ips = dns.lookup_host_name("localhost").await.unwrap(); + assert!( + !ips.is_empty(), + "Expect localhost to resolve to more than 1 IPs." + ) + } + + #[tokio::test] + async fn default_resolver_txt_fails() { + let default_resolver = TokioDefaultDnsResolver::new(ResolverOptions::default()).unwrap(); + + let txt = default_resolver.lookup_txt("google.com").await; + assert!(txt.is_err()) + } + + #[tokio::test] + async fn default_resolver_custom_authority() { + let opts = ResolverOptions { + server_addr: Some("8.8.8.8:53".parse().unwrap()), + }; + let default_resolver = TokioDefaultDnsResolver::new(opts); + assert!(default_resolver.is_err()) + } +} From 1f8e819327b66cbc37d812eaa736a17b7f61f114 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Fri, 23 May 2025 01:11:18 +0530 Subject: [PATCH 2/7] Remove unstable deps --- grpc/src/client/name_resolution/dns/mod.rs | 23 +++++++++++---------- grpc/src/client/name_resolution/dns/test.rs | 8 +++---- grpc/src/client/name_resolution/mod.rs | 2 +- grpc/src/client/name_resolution/registry.rs | 9 +++++--- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs index 14f6b0393..9b29f778c 100644 --- a/grpc/src/client/name_resolution/dns/mod.rs +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -21,11 +21,13 @@ use std::{ net::{IpAddr, SocketAddr}, - sync::{Arc, Mutex}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, + }, time::{Duration, SystemTime}, }; -use once_cell::sync::Lazy; use tokio::sync::mpsc::UnboundedSender; use url::Host; @@ -36,7 +38,7 @@ use crate::{ use super::{ backoff::{BackoffConfig, ExponentialBackoff, DEFAULT_EXPONENTIAL_CONFIG}, - Endpoint, Resolver, 
ResolverBuilder, GLOBAL_RESOLVER_REGISTRY, + global_registry, Endpoint, Resolver, ResolverBuilder, }; #[cfg(test)] @@ -51,31 +53,30 @@ const DEFAULT_DNS_PORT: u16 = 53; /// /// It is recommended to set this value at application startup. Avoid modifying /// this variable after initialization. -static RESOLVING_TIMEOUT: Lazy> = Lazy::new(|| Mutex::new(Duration::from_secs(30))); +static RESOLVING_TIMEOUT_MS: AtomicU64 = AtomicU64::new(30_000); // 30 seconds /// This is the minimum interval at which re-resolutions are allowed. This helps /// to prevent excessive re-resolution. -static MIN_RESOLUTION_INTERVAL: Lazy> = - Lazy::new(|| Mutex::new(Duration::from_secs(30))); +static MIN_RESOLUTION_INTERVAL_MS: AtomicU64 = AtomicU64::new(30_000); // 30 seconds pub fn get_resolving_timeout() -> Duration { - RESOLVING_TIMEOUT.lock().unwrap().clone() + Duration::from_millis(RESOLVING_TIMEOUT_MS.load(Ordering::Relaxed)) } pub fn set_resolving_timeout(duration: Duration) { - *RESOLVING_TIMEOUT.lock().unwrap() = duration; + RESOLVING_TIMEOUT_MS.store(duration.as_millis() as u64, Ordering::Relaxed); } pub fn get_min_resolution_interval() -> Duration { - MIN_RESOLUTION_INTERVAL.lock().unwrap().clone() + Duration::from_millis(MIN_RESOLUTION_INTERVAL_MS.load(Ordering::Relaxed)) } pub fn set_min_resolution_interval(duration: Duration) { - *MIN_RESOLUTION_INTERVAL.lock().unwrap() = duration; + MIN_RESOLUTION_INTERVAL_MS.store(duration.as_millis() as u64, Ordering::Relaxed); } pub fn reg() { - GLOBAL_RESOLVER_REGISTRY.add_builder(Box::new(Builder {})); + global_registry().add_builder(Box::new(Builder {})); } struct Builder {} diff --git a/grpc/src/client/name_resolution/dns/test.rs b/grpc/src/client/name_resolution/dns/test.rs index 6ef1077ea..5fad96e4a 100644 --- a/grpc/src/client/name_resolution/dns/test.rs +++ b/grpc/src/client/name_resolution/dns/test.rs @@ -25,7 +25,7 @@ use crate::{ self, backoff::{BackoffConfig, DEFAULT_EXPONENTIAL_CONFIG}, dns::{parse_endpoint_and_authority, 
HostPort}, - ResolverOptions, ResolverUpdate, Target, GLOBAL_RESOLVER_REGISTRY, + global_registry, ResolverOptions, ResolverUpdate, Target, }, rt::{self, tokio::TokioRuntime}, }; @@ -183,7 +183,7 @@ impl name_resolution::ChannelController for FakeChannelController { #[tokio::test] pub async fn dns_basic() { super::reg(); - let builder = GLOBAL_RESOLVER_REGISTRY.get("dns").unwrap(); + let builder = global_registry().get("dns").unwrap(); let target = &"dns:///localhost:1234".parse().unwrap(); let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(WorkScheduler { @@ -212,7 +212,7 @@ pub async fn dns_basic() { #[tokio::test] pub async fn invalid_target() { super::reg(); - let builder = GLOBAL_RESOLVER_REGISTRY.get("dns").unwrap(); + let builder = global_registry().get("dns").unwrap(); let target = &"dns:///:1234".parse().unwrap(); let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(WorkScheduler { @@ -288,7 +288,7 @@ impl rt::Runtime for FakeRuntime { #[tokio::test] pub async fn dns_lookup_error() { super::reg(); - let builder = GLOBAL_RESOLVER_REGISTRY.get("dns").unwrap(); + let builder = global_registry().get("dns").unwrap(); let target = &"dns:///grpc.io:1234".parse().unwrap(); let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(WorkScheduler { diff --git a/grpc/src/client/name_resolution/mod.rs b/grpc/src/client/name_resolution/mod.rs index 1a24c371f..6687d757e 100644 --- a/grpc/src/client/name_resolution/mod.rs +++ b/grpc/src/client/name_resolution/mod.rs @@ -35,7 +35,7 @@ use std::{ mod backoff; mod dns; mod registry; -pub use registry::GLOBAL_RESOLVER_REGISTRY; +pub use registry::global_registry; /// Target represents a target for gRPC, as specified in: /// https://github.com/grpc/grpc/blob/master/doc/naming.md. 
diff --git a/grpc/src/client/name_resolution/registry.rs b/grpc/src/client/name_resolution/registry.rs index 26ecc93b1..fa4a42163 100644 --- a/grpc/src/client/name_resolution/registry.rs +++ b/grpc/src/client/name_resolution/registry.rs @@ -18,11 +18,13 @@ use std::{ collections::HashMap, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, OnceLock}, }; use super::ResolverBuilder; +static GLOBAL_RESOLVER_REGISTRY: OnceLock = OnceLock::new(); + /// A registry to store and retrieve name resolvers. Resolvers are indexed by /// the URI scheme they are intended to handle. #[derive(Default)] @@ -66,5 +68,6 @@ impl ResolverRegistry { } /// Global registry for resolver builders. -pub static GLOBAL_RESOLVER_REGISTRY: std::sync::LazyLock = - std::sync::LazyLock::new(ResolverRegistry::new); +pub fn global_registry() -> &'static ResolverRegistry { + GLOBAL_RESOLVER_REGISTRY.get_or_init(|| ResolverRegistry::new()) +} From d7724ffb3b884eb1c9d0fde360353e8acac3a98b Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 29 May 2025 17:06:23 +0530 Subject: [PATCH 3/7] let backoff creation fail --- grpc/src/client/name_resolution/backoff.rs | 77 ++++++++++++---------- grpc/src/client/name_resolution/dns/mod.rs | 3 +- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/grpc/src/client/name_resolution/backoff.rs b/grpc/src/client/name_resolution/backoff.rs index a0709ef8f..2cd11b359 100644 --- a/grpc/src/client/name_resolution/backoff.rs +++ b/grpc/src/client/name_resolution/backoff.rs @@ -55,21 +55,36 @@ pub const DEFAULT_EXPONENTIAL_CONFIG: BackoffConfig = BackoffConfig { max_delay: Duration::from_secs(120), }; -impl ExponentialBackoff { - pub fn new(mut config: BackoffConfig) -> Self { - // Adjust params to get them in valid ranges. +impl BackoffConfig { + fn validate(&self) -> Result<(), &'static str> { + // Valid that params are in valid ranges. 
// 0 <= base_dealy <= max_delay - config.base_delay = config.base_delay.min(config.max_delay); + if self.base_delay > self.max_delay { + Err("base_delay must be greater than max_delay")?; + } // 1 <= multiplier - config.multiplier = config.multiplier.max(1.0); + if self.multiplier < 1.0 { + Err("multiplier must be greater than 1.0")?; + } // 0 <= jitter <= 1 - config.jitter = config.jitter.max(0.0); - config.jitter = config.jitter.min(1.0); + if self.jitter < 0.0 { + Err("jitter must be greater than or equal to 0")?; + } + if self.jitter > 1.0 { + Err("jitter must be less than or equal to 1")? + } + Ok(()) + } +} + +impl ExponentialBackoff { + pub fn new(config: BackoffConfig) -> Result { + config.validate()?; let next_delay_secs = config.base_delay.as_secs_f64(); - ExponentialBackoff { + Ok(ExponentialBackoff { config, next_delay_secs: Mutex::new(next_delay_secs), - } + }) } pub fn reset(&self) { @@ -94,12 +109,20 @@ impl ExponentialBackoff { mod tests { use std::time::Duration; - use crate::client::name_resolution::backoff::{BackoffConfig, ExponentialBackoff}; + use crate::client::name_resolution::backoff::{ + BackoffConfig, ExponentialBackoff, DEFAULT_EXPONENTIAL_CONFIG, + }; // Epsilon for floating point comparisons if needed, though Duration // comparisons are often better. 
const EPSILON: f64 = 1e-9; + #[test] + fn default_config_is_valid() { + let result = ExponentialBackoff::new(DEFAULT_EXPONENTIAL_CONFIG.clone()); + assert_eq!(result.is_ok(), true); + } + #[test] fn base_less_than_max() { let config = BackoffConfig { @@ -108,7 +131,7 @@ mod tests { jitter: 0.0, max_delay: Duration::from_secs(100), }; - let backoff = ExponentialBackoff::new(config.clone()); + let backoff = ExponentialBackoff::new(config).unwrap(); assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); } @@ -120,8 +143,8 @@ mod tests { base_delay: Duration::from_secs(100), max_delay: Duration::from_secs(10), }; - let backoff = ExponentialBackoff::new(config.clone()); - assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + let result = ExponentialBackoff::new(config); + assert_eq!(result.is_err(), true); } #[test] @@ -132,10 +155,8 @@ mod tests { base_delay: Duration::from_secs(10), max_delay: Duration::from_secs(100), }; - let backoff = ExponentialBackoff::new(config.clone()); - // multiplier gets clipped to 1. - assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); - assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + let result = ExponentialBackoff::new(config); + assert_eq!(result.is_err(), true); } #[test] @@ -146,10 +167,8 @@ mod tests { base_delay: Duration::from_secs(10), max_delay: Duration::from_secs(100), }; - let backoff = ExponentialBackoff::new(config.clone()); - // jitter gets clipped to 0. - assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); - assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); + let result = ExponentialBackoff::new(config); + assert_eq!(result.is_err(), true); } #[test] @@ -160,16 +179,8 @@ mod tests { base_delay: Duration::from_secs(10), max_delay: Duration::from_secs(100), }; - let backoff = ExponentialBackoff::new(config.clone()); - // jitter gets clipped to 1. - // 0 <= duration <= 20. 
- let duration = backoff.backoff_duration(); - assert_eq!(duration.lt(&Duration::from_secs(20)), true); - assert_eq!(duration.gt(&Duration::from_secs(0)), true); - - let duration = backoff.backoff_duration(); - assert_eq!(duration.lt(&Duration::from_secs(20)), true); - assert_eq!(duration.gt(&Duration::from_secs(0)), true); + let result = ExponentialBackoff::new(config); + assert_eq!(result.is_err(), true); } #[test] @@ -180,7 +191,7 @@ mod tests { base_delay: Duration::from_secs(1), max_delay: Duration::from_secs(15), }; - let backoff = ExponentialBackoff::new(config.clone()); + let backoff = ExponentialBackoff::new(config.clone()).unwrap(); assert_eq!(backoff.backoff_duration(), Duration::from_secs(1)); assert_eq!(backoff.backoff_duration(), Duration::from_secs(2)); assert_eq!(backoff.backoff_duration(), Duration::from_secs(4)); @@ -208,7 +219,7 @@ mod tests { base_delay: Duration::from_secs(1), max_delay: Duration::from_secs(15), }; - let backoff = ExponentialBackoff::new(config.clone()); + let backoff = ExponentialBackoff::new(config.clone()).unwrap(); // 0.8 <= duration <= 1.2. 
let duration = backoff.backoff_duration(); assert_eq!(duration.gt(&Duration::from_secs_f64(0.8 - EPSILON)), true); diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs index 9b29f778c..95d0277c1 100644 --- a/grpc/src/client/name_resolution/dns/mod.rs +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -123,7 +123,8 @@ impl DnsResolver { tokio::sync::mpsc::unbounded_channel::>(); let handle = options.runtime.clone().spawn(Box::pin(async move { - let backoff = ExponentialBackoff::new(dns_opts.backoff_config.clone()); + let backoff = ExponentialBackoff::new(dns_opts.backoff_config.clone()) + .expect("default exponential config must be valid"); let state = state_copy; let work_scheduler = options.work_scheduler; let mut update_error_rx = update_error_rx; From eeda1e6418253d0f851c3964f7d4f0e8a1e94053 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 29 May 2025 22:31:30 +0530 Subject: [PATCH 4/7] Address comments in dns implementation --- grpc/src/client/name_resolution/dns/mod.rs | 73 +++++++++++++--------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs index 95d0277c1..e4a72380d 100644 --- a/grpc/src/client/name_resolution/dns/mod.rs +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -28,7 +28,7 @@ use std::{ time::{Duration, SystemTime}, }; -use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::Notify; use url::Host; use crate::{ @@ -59,18 +59,35 @@ static RESOLVING_TIMEOUT_MS: AtomicU64 = AtomicU64::new(30_000); // 30 seconds /// to prevent excessive re-resolution. static MIN_RESOLUTION_INTERVAL_MS: AtomicU64 = AtomicU64::new(30_000); // 30 seconds -pub fn get_resolving_timeout() -> Duration { +fn get_resolving_timeout() -> Duration { Duration::from_millis(RESOLVING_TIMEOUT_MS.load(Ordering::Relaxed)) } +/// Sets the maximum duration for DNS resolution requests. 
+/// +/// This function affects the global timeout used by all channels using the DNS +/// name resolver scheme. +/// +/// It must be called only at application startup, before any gRPC calls are +/// made. +/// +/// The default value is 30 seconds. Setting the timeout too low may result in +/// premature timeouts during resolution, while setting it too high may lead to +/// unnecessary delays in service discovery. Choose a value appropriate for your +/// specific needs and network environment. pub fn set_resolving_timeout(duration: Duration) { RESOLVING_TIMEOUT_MS.store(duration.as_millis() as u64, Ordering::Relaxed); } -pub fn get_min_resolution_interval() -> Duration { +fn get_min_resolution_interval() -> Duration { Duration::from_millis(MIN_RESOLUTION_INTERVAL_MS.load(Ordering::Relaxed)) } +/// Sets the default minimum interval at which DNS re-resolutions are allowed. +/// This helps to prevent excessive re-resolution. +/// +/// It must be called only at application startup, before any gRPC calls are +/// made. 
pub fn set_min_resolution_interval(duration: Duration) { MIN_RESOLUTION_INTERVAL_MS.store(duration.as_millis() as u64, Ordering::Relaxed); } @@ -116,18 +133,19 @@ impl DnsResolver { }; let state = Arc::new(Mutex::new(InternalState { addrs: Ok(Vec::new()), + channel_response: None, })); let state_copy = state.clone(); - let (resolve_now_tx, mut resolve_now_rx) = tokio::sync::mpsc::unbounded_channel::<()>(); - let (update_error_tx, update_error_rx) = - tokio::sync::mpsc::unbounded_channel::>(); + let resolve_now_notify = Arc::new(Notify::new()); + let channel_updated_notify = Arc::new(Notify::new()); + let channel_updated_rx = channel_updated_notify.clone(); + let resolve_now_rx = resolve_now_notify.clone(); let handle = options.runtime.clone().spawn(Box::pin(async move { let backoff = ExponentialBackoff::new(dns_opts.backoff_config.clone()) .expect("default exponential config must be valid"); let state = state_copy; let work_scheduler = options.work_scheduler; - let mut update_error_rx = update_error_rx; loop { let mut lookup_fut = dns.lookup_host_name(&host); let mut timeout_fut = options.runtime.sleep(dns_opts.resolving_timeout); @@ -156,12 +174,13 @@ impl DnsResolver { internal_state.addrs = addrs; } work_scheduler.schedule_work(); - let update_result = match update_error_rx.recv().await { - Some(res) => res, - None => return, - }; + channel_updated_rx.notified().await; + let channel_response: Option; + { + channel_response = state.lock().unwrap().channel_response.take(); + } let next_resoltion_time: SystemTime; - if update_result.is_err() { + if channel_response.is_some() { next_resoltion_time = SystemTime::now() .checked_add(backoff.backoff_duration()) .unwrap(); @@ -173,12 +192,11 @@ impl DnsResolver { next_resoltion_time = SystemTime::now() .checked_add(dns_opts.min_resolution_interval) .unwrap(); - _ = resolve_now_rx.recv().await; + _ = resolve_now_rx.notified().await; } // Wait till next resolution time. 
- let duration = match next_resoltion_time.duration_since(SystemTime::now()) { - Ok(d) => d, - Err(_) => continue, // Time has already passed. + let Ok(duration) = next_resoltion_time.duration_since(SystemTime::now()) else { + continue; // Time has already passed. }; options.runtime.sleep(duration).await; } @@ -187,8 +205,8 @@ impl DnsResolver { Box::new(DnsResolver { state, task_handle: handle, - resolve_now_requester: resolve_now_tx, - update_error_sender: update_error_tx, + resolve_now_notifier: resolve_now_notify, + channel_update_notifier: channel_updated_notify, }) } } @@ -224,27 +242,23 @@ impl ResolverBuilder for Builder { struct DnsResolver { state: Arc>, task_handle: Box, - resolve_now_requester: UnboundedSender<()>, - update_error_sender: UnboundedSender>, + resolve_now_notifier: Arc, + channel_update_notifier: Arc, } struct InternalState { addrs: Result, String>, + // Error from the latest call to channel_controller.update(). + channel_response: Option, } impl Resolver for DnsResolver { fn resolve_now(&mut self) { - _ = self.resolve_now_requester.send(()); + _ = self.resolve_now_notifier.notify_one(); } fn work(&mut self, channel_controller: &mut dyn super::ChannelController) { - let state = match self.state.lock() { - Err(_) => { - eprintln!("DNS resolver mutex poisoned, can't update channel"); - return; - } - Ok(s) => s, - }; + let mut state = self.state.lock().unwrap(); let endpoint_result = match &state.addrs { Ok(addrs) => { let endpoints: Vec<_> = addrs @@ -267,7 +281,8 @@ impl Resolver for DnsResolver { ..Default::default() }; let status = channel_controller.update(update); - _ = self.update_error_sender.send(status); + state.channel_response = status.err(); + _ = self.channel_update_notifier.notify_one(); } } From ec193f03ad01a3e565d87140acee91274015ef46 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 29 May 2025 23:48:09 +0530 Subject: [PATCH 5/7] Name resolution API changes --- grpc/src/client/name_resolution/mod.rs | 25 
+++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/grpc/src/client/name_resolution/mod.rs b/grpc/src/client/name_resolution/mod.rs index 6687d757e..5c808250c 100644 --- a/grpc/src/client/name_resolution/mod.rs +++ b/grpc/src/client/name_resolution/mod.rs @@ -100,7 +100,7 @@ impl Display for Target { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, - "{}//{}/{}", + "{}://{}{}", self.scheme(), self.authority_host_port(), self.path() @@ -166,17 +166,16 @@ pub trait Resolver: Send { /// Asks the resolver to obtain an updated resolver result, if /// applicable. /// - /// This is useful for pull-based implementations to decide when to - /// re-resolve. However, the implementation is not required to + /// This is useful for polling resolvers to decide when to re-resolve. + /// However, the implementation is not required to /// re-resolve immediately upon receiving this call; it may instead /// elect to delay based on some configured minimum time between /// queries, to avoid hammering the name service with queries. /// - /// For push-based implementations, this may be a no-op. + /// For watch based resolvers, this may be a no-op. fn resolve_now(&mut self); - /// Called serially by the channel to do work after the work scheduler's - /// schedule_work method is called. + /// Called serially by the channel to to allow access to ChannelController. 
fn work(&mut self, channel_controller: &mut dyn ChannelController); } @@ -225,10 +224,10 @@ pub struct ResolverUpdate { impl Default for ResolverUpdate { fn default() -> Self { ResolverUpdate { - service_config: Ok(None), - attributes: Arc::default(), - endpoints: Ok(Vec::default()), - resolution_note: None, + service_config: Ok(Default::default()), + attributes: Default::default(), + endpoints: Ok(Default::default()), + resolution_note: Default::default(), } } } @@ -331,6 +330,7 @@ mod test { want_port: Option, want_host_port: &'static str, want_path: &'static str, + want_str: &'static str, } let test_cases = vec![ TestCase { @@ -340,6 +340,7 @@ mod test { want_host: "", want_port: None, want_path: "/grpc.io", + want_str: "dns:///grpc.io", }, TestCase { input: "dns://8.8.8.8:53/grpc.io/docs", @@ -348,6 +349,7 @@ mod test { want_host: "8.8.8.8", want_port: Some(53), want_path: "/grpc.io/docs", + want_str: "dns://8.8.8.8:53/grpc.io/docs", }, TestCase { input: "unix:path/to/file", @@ -356,6 +358,7 @@ mod test { want_host: "", want_port: None, want_path: "path/to/file", + want_str: "unix://path/to/file", }, TestCase { input: "unix:///run/containerd/containerd.sock", @@ -364,6 +367,7 @@ mod test { want_host: "", want_port: None, want_path: "/run/containerd/containerd.sock", + want_str: "unix:///run/containerd/containerd.sock", }, ]; @@ -374,6 +378,7 @@ mod test { assert_eq!(target.authority_port(), tc.want_port); assert_eq!(target.authority_host_port(), tc.want_host_port); assert_eq!(target.path(), tc.want_path); + assert_eq!(&target.to_string(), tc.want_str); } } } From a73c5c01f7a6febcf6638c182d4de75c17ff656c Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Wed, 4 Jun 2025 00:50:26 +0530 Subject: [PATCH 6/7] remove mutex from backoff --- grpc/src/client/name_resolution/backoff.rs | 27 +++++++++++----------- grpc/src/client/name_resolution/dns/mod.rs | 2 +- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/grpc/src/client/name_resolution/backoff.rs 
b/grpc/src/client/name_resolution/backoff.rs index 2cd11b359..894155f80 100644 --- a/grpc/src/client/name_resolution/backoff.rs +++ b/grpc/src/client/name_resolution/backoff.rs @@ -17,7 +17,7 @@ */ use rand::Rng; -use std::{sync::Mutex, time::Duration}; +use std::time::Duration; #[derive(Clone)] pub struct BackoffConfig { @@ -40,7 +40,7 @@ pub struct ExponentialBackoff { /// The delay for the next retry, without the random jitter. Store as f64 /// to avoid rounding errors. - next_delay_secs: Mutex, + next_delay_secs: f64, } /// This is a backoff configuration with the default values specified @@ -83,24 +83,23 @@ impl ExponentialBackoff { let next_delay_secs = config.base_delay.as_secs_f64(); Ok(ExponentialBackoff { config, - next_delay_secs: Mutex::new(next_delay_secs), + next_delay_secs: next_delay_secs, }) } - pub fn reset(&self) { - let mut next_delay = self.next_delay_secs.lock().unwrap(); - *next_delay = self.config.base_delay.as_secs_f64(); + pub fn reset(&mut self) { + self.next_delay_secs = self.config.base_delay.as_secs_f64(); } - pub fn backoff_duration(&self) -> Duration { - let mut next_delay = self.next_delay_secs.lock().unwrap(); + pub fn backoff_duration(&mut self) -> Duration { + let next_delay = self.next_delay_secs; let cur_delay = - *next_delay * (1.0 + self.config.jitter * rand::thread_rng().gen_range(-1.0..1.0)); - *next_delay = self + next_delay * (1.0 + self.config.jitter * rand::thread_rng().gen_range(-1.0..1.0)); + self.next_delay_secs = self .config .max_delay .as_secs_f64() - .min(*next_delay * self.config.multiplier); + .min(next_delay * self.config.multiplier); Duration::from_secs_f64(cur_delay) } } @@ -131,7 +130,7 @@ mod tests { jitter: 0.0, max_delay: Duration::from_secs(100), }; - let backoff = ExponentialBackoff::new(config).unwrap(); + let mut backoff = ExponentialBackoff::new(config).unwrap(); assert_eq!(backoff.backoff_duration(), Duration::from_secs(10)); } @@ -191,7 +190,7 @@ mod tests { base_delay: Duration::from_secs(1), 
max_delay: Duration::from_secs(15), }; - let backoff = ExponentialBackoff::new(config.clone()).unwrap(); + let mut backoff = ExponentialBackoff::new(config.clone()).unwrap(); assert_eq!(backoff.backoff_duration(), Duration::from_secs(1)); assert_eq!(backoff.backoff_duration(), Duration::from_secs(2)); assert_eq!(backoff.backoff_duration(), Duration::from_secs(4)); @@ -219,7 +218,7 @@ mod tests { base_delay: Duration::from_secs(1), max_delay: Duration::from_secs(15), }; - let backoff = ExponentialBackoff::new(config.clone()).unwrap(); + let mut backoff = ExponentialBackoff::new(config.clone()).unwrap(); // 0.8 <= duration <= 1.2. let duration = backoff.backoff_duration(); assert_eq!(duration.gt(&Duration::from_secs_f64(0.8 - EPSILON)), true); diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs index e4a72380d..76abe5be0 100644 --- a/grpc/src/client/name_resolution/dns/mod.rs +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -142,7 +142,7 @@ impl DnsResolver { let resolve_now_rx = resolve_now_notify.clone(); let handle = options.runtime.clone().spawn(Box::pin(async move { - let backoff = ExponentialBackoff::new(dns_opts.backoff_config.clone()) + let mut backoff = ExponentialBackoff::new(dns_opts.backoff_config.clone()) .expect("default exponential config must be valid"); let state = state_copy; let work_scheduler = options.work_scheduler; From 0a49114d33a8816d74bc34ba71e6d95e0f53d7d0 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Wed, 4 Jun 2025 14:52:12 +0530 Subject: [PATCH 7/7] Address review comments --- grpc/Cargo.toml | 1 - grpc/src/client/name_resolution/backoff.rs | 4 +- grpc/src/client/name_resolution/dns/mod.rs | 93 ++++++++++----------- grpc/src/client/name_resolution/dns/test.rs | 46 ++++++---- grpc/src/client/name_resolution/mod.rs | 36 ++++---- grpc/src/client/name_resolution/registry.rs | 8 +- 6 files changed, 101 insertions(+), 87 deletions(-) diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index 
017e6e37f..046addfbd 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -10,7 +10,6 @@ url = "2.5.0" tokio = { version = "1.37.0", features = ["sync", "rt", "net", "time", "macros"] } tonic = { version = "0.13.0", path = "../tonic", default-features = false, features = ["codegen"] } futures-core = "0.3.31" -once_cell = "1.19.0" hickory-resolver = { version = "0.25.1", optional = true } rand = "0.8.5" diff --git a/grpc/src/client/name_resolution/backoff.rs b/grpc/src/client/name_resolution/backoff.rs index 894155f80..14ec56053 100644 --- a/grpc/src/client/name_resolution/backoff.rs +++ b/grpc/src/client/name_resolution/backoff.rs @@ -57,7 +57,7 @@ pub const DEFAULT_EXPONENTIAL_CONFIG: BackoffConfig = BackoffConfig { impl BackoffConfig { fn validate(&self) -> Result<(), &'static str> { - // Valid that params are in valid ranges. + // Check that the arguments are in valid ranges. // 0 <= base_dealy <= max_delay if self.base_delay > self.max_delay { Err("base_delay must be greater than max_delay")?; @@ -83,7 +83,7 @@ impl ExponentialBackoff { let next_delay_secs = config.base_delay.as_secs_f64(); Ok(ExponentialBackoff { config, - next_delay_secs: next_delay_secs, + next_delay_secs, }) } diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs index 76abe5be0..4bfd3eec5 100644 --- a/grpc/src/client/name_resolution/dns/mod.rs +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -31,14 +31,11 @@ use std::{ use tokio::sync::Notify; use url::Host; -use crate::{ - client::name_resolution::{Address, NopResolver, ResolverUpdate, TCP_IP_NETWORK_TYPE}, - rt, -}; +use crate::rt; use super::{ backoff::{BackoffConfig, ExponentialBackoff, DEFAULT_EXPONENTIAL_CONFIG}, - global_registry, Endpoint, Resolver, ResolverBuilder, + Address, Endpoint, NopResolver, Resolver, ResolverOptions, ResolverUpdate, TCP_IP_NETWORK_TYPE, }; #[cfg(test)] @@ -93,7 +90,7 @@ pub fn set_min_resolution_interval(duration: Duration) { } pub fn reg() { - 
global_registry().add_builder(Box::new(Builder {})); + super::global_registry().add_builder(Box::new(Builder {})); } struct Builder {} @@ -102,35 +99,16 @@ struct DnsOptions { min_resolution_interval: Duration, resolving_timeout: Duration, backoff_config: BackoffConfig, + host: String, + port: u16, } impl DnsResolver { fn new( - target: &super::Target, - options: super::ResolverOptions, + dns_client: Box, + options: ResolverOptions, dns_opts: DnsOptions, - ) -> Box { - let parsed = match parse_endpoint_and_authority(target) { - Ok(res) => res, - Err(err) => return nop_resolver_for_err(err.to_string(), options), - }; - let endpoint = parsed.endpoint; - let host = match endpoint.host { - Host::Domain(d) => d, - Host::Ipv4(ipv4) => { - return nop_resolver_for_ip(IpAddr::V4(ipv4), endpoint.port, options) - } - Host::Ipv6(ipv6) => { - return nop_resolver_for_ip(IpAddr::V6(ipv6), endpoint.port, options) - } - }; - let authority = parsed.authority; - let dns = match options.runtime.get_dns_resolver(rt::ResolverOptions { - server_addr: authority, - }) { - Ok(dns) => dns, - Err(err) => return nop_resolver_for_err(err.to_string(), options), - }; + ) -> Self { let state = Arc::new(Mutex::new(InternalState { addrs: Ok(Vec::new()), channel_response: None, @@ -147,7 +125,7 @@ impl DnsResolver { let state = state_copy; let work_scheduler = options.work_scheduler; loop { - let mut lookup_fut = dns.lookup_host_name(&host); + let mut lookup_fut = dns_client.lookup_host_name(&dns_opts.host); let mut timeout_fut = options.runtime.sleep(dns_opts.resolving_timeout); let addrs = tokio::select! 
{ result = &mut lookup_fut => { @@ -155,7 +133,7 @@ impl DnsResolver { Ok(ips) => { let addrs = ips .into_iter() - .map(|ip| SocketAddr::new(ip, endpoint.port)) + .map(|ip| SocketAddr::new(ip, dns_opts.port)) .collect(); Ok(addrs) } @@ -202,27 +180,46 @@ impl DnsResolver { } })); - Box::new(DnsResolver { + Self { state, task_handle: handle, resolve_now_notifier: resolve_now_notify, channel_update_notifier: channel_updated_notify, - }) + } } } -impl ResolverBuilder for Builder { - fn build( - &self, - target: &super::Target, - options: super::ResolverOptions, - ) -> Box { +impl super::ResolverBuilder for Builder { + fn build(&self, target: &super::Target, options: ResolverOptions) -> Box { + let parsed = match parse_endpoint_and_authority(target) { + Ok(res) => res, + Err(err) => return nop_resolver_for_err(err.to_string(), options), + }; + let endpoint = parsed.endpoint; + let host = match endpoint.host { + Host::Domain(d) => d, + Host::Ipv4(ipv4) => { + return nop_resolver_for_ip(IpAddr::V4(ipv4), endpoint.port, options) + } + Host::Ipv6(ipv6) => { + return nop_resolver_for_ip(IpAddr::V6(ipv6), endpoint.port, options) + } + }; + let authority = parsed.authority; + let dns_client = match options.runtime.get_dns_resolver(rt::ResolverOptions { + server_addr: authority, + }) { + Ok(dns) => dns, + Err(err) => return nop_resolver_for_err(err.to_string(), options), + }; let dns_opts = DnsOptions { min_resolution_interval: get_min_resolution_interval(), resolving_timeout: get_resolving_timeout(), backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + host, + port: endpoint.port, }; - DnsResolver::new(target, options, dns_opts) + Box::new(DnsResolver::new(dns_client, options, dns_opts)) } fn scheme(&self) -> &'static str { @@ -254,7 +251,7 @@ struct InternalState { impl Resolver for DnsResolver { fn resolve_now(&mut self) { - _ = self.resolve_now_notifier.notify_one(); + self.resolve_now_notifier.notify_one(); } fn work(&mut self, channel_controller: &mut dyn 
super::ChannelController) { @@ -282,7 +279,7 @@ impl Resolver for DnsResolver { }; let status = channel_controller.update(update); state.channel_response = status.err(); - _ = self.channel_update_notifier.notify_one(); + self.channel_update_notifier.notify_one(); } } @@ -344,7 +341,7 @@ fn parse_endpoint_and_authority(target: &super::Target) -> Result is returned. +/// Ok(None) is returned. fn parse_host_port(host_and_port: &str, default_port: u16) -> Result, String> { // We need to use the https scheme otherwise url::Url::parse doesn't convert // IP addresses to Host::Ipv4 or Host::Ipv6 if they could represent valid @@ -365,11 +362,7 @@ fn parse_host_port(host_and_port: &str, default_port: u16) -> Result Box { +fn nop_resolver_for_ip(ip: IpAddr, port: u16, options: ResolverOptions) -> Box { options.work_scheduler.schedule_work(); Box::new(NopResolver { update: ResolverUpdate { @@ -386,7 +379,7 @@ fn nop_resolver_for_ip( }) } -fn nop_resolver_for_err(err: String, options: super::ResolverOptions) -> Box { +fn nop_resolver_for_err(err: String, options: ResolverOptions) -> Box { options.work_scheduler.schedule_work(); Box::new(NopResolver { update: ResolverUpdate { diff --git a/grpc/src/client/name_resolution/dns/test.rs b/grpc/src/client/name_resolution/dns/test.rs index 5fad96e4a..ddf3b6104 100644 --- a/grpc/src/client/name_resolution/dns/test.rs +++ b/grpc/src/client/name_resolution/dns/test.rs @@ -25,12 +25,12 @@ use crate::{ self, backoff::{BackoffConfig, DEFAULT_EXPONENTIAL_CONFIG}, dns::{parse_endpoint_and_authority, HostPort}, - global_registry, ResolverOptions, ResolverUpdate, Target, + global_registry, Resolver, ResolverOptions, ResolverUpdate, Target, }, rt::{self, tokio::TokioRuntime}, }; -use super::ParseResult; +use super::{DnsOptions, ParseResult}; const DEFAULT_TEST_SHORT_TIMEOUT: Duration = Duration::from_millis(10); @@ -323,7 +323,6 @@ pub async fn dns_lookup_error() { #[tokio::test] pub async fn dns_lookup_timeout() { - let target = 
&"dns:///grpc.io:1234".parse().unwrap(); let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(WorkScheduler { work_tx: work_tx.clone(), @@ -335,17 +334,20 @@ pub async fn dns_lookup_timeout() { lookup_result: Ok(Vec::new()), }, }; + let dns_client = runtime.dns.clone(); let opts = ResolverOptions { authority: "ignored".to_string(), runtime: Arc::new(runtime), work_scheduler: work_scheduler.clone(), }; - let dns_opts = super::DnsOptions { + let dns_opts = DnsOptions { min_resolution_interval: super::get_min_resolution_interval(), resolving_timeout: DEFAULT_TEST_SHORT_TIMEOUT, backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + host: "grpc.io".to_string(), + port: 1234, }; - let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + let mut resolver = super::DnsResolver::new(Box::new(dns_client), opts, dns_opts); // Wait for schedule work to be called. let _ = work_rx.recv().await.unwrap(); @@ -363,7 +365,6 @@ pub async fn dns_lookup_timeout() { #[tokio::test] pub async fn rate_limit() { - let target = &"dns:///localhost:1234".parse().unwrap(); let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(WorkScheduler { work_tx: work_tx.clone(), @@ -373,12 +374,18 @@ pub async fn rate_limit() { runtime: Arc::new(TokioRuntime {}), work_scheduler: work_scheduler.clone(), }; - let dns_opts = super::DnsOptions { + let dns_client = opts + .runtime + .get_dns_resolver(rt::ResolverOptions { server_addr: None }) + .unwrap(); + let dns_opts = DnsOptions { min_resolution_interval: Duration::from_secs(20), resolving_timeout: super::get_resolving_timeout(), backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + host: "localhost".to_string(), + port: 1234, }; - let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + let mut resolver = super::DnsResolver::new(dns_client, opts, dns_opts); // Wait for schedule work to be called. 
let event = work_rx.recv().await.unwrap(); @@ -408,7 +415,6 @@ pub async fn rate_limit() { #[tokio::test] pub async fn re_resolution_after_success() { - let target = &"dns:///localhost:1234".parse().unwrap(); let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(WorkScheduler { work_tx: work_tx.clone(), @@ -418,12 +424,18 @@ pub async fn re_resolution_after_success() { runtime: Arc::new(TokioRuntime {}), work_scheduler: work_scheduler.clone(), }; - let dns_opts = super::DnsOptions { + let dns_opts = DnsOptions { min_resolution_interval: Duration::from_millis(1), resolving_timeout: super::get_resolving_timeout(), backoff_config: DEFAULT_EXPONENTIAL_CONFIG, + host: "localhost".to_string(), + port: 1234, }; - let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + let dns_client = opts + .runtime + .get_dns_resolver(rt::ResolverOptions { server_addr: None }) + .unwrap(); + let mut resolver = super::DnsResolver::new(dns_client, opts, dns_opts); // Wait for schedule work to be called. let _ = work_rx.recv().await.unwrap(); @@ -447,7 +459,6 @@ pub async fn re_resolution_after_success() { #[tokio::test] pub async fn backoff_on_error() { - let target = &"dns:///localhost:1234".parse().unwrap(); let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(WorkScheduler { work_tx: work_tx.clone(), @@ -457,7 +468,7 @@ pub async fn backoff_on_error() { runtime: Arc::new(TokioRuntime {}), work_scheduler: work_scheduler.clone(), }; - let dns_opts = super::DnsOptions { + let dns_opts = DnsOptions { min_resolution_interval: Duration::from_millis(1), resolving_timeout: super::get_resolving_timeout(), // Speed up the backoffs to make the test run faster. 
@@ -467,8 +478,15 @@ pub async fn backoff_on_error() { jitter: 0.0, max_delay: Duration::from_millis(1), }, + host: "localhost".to_string(), + port: 1234, }; - let mut resolver = super::DnsResolver::new(target, opts, dns_opts); + let dns_client = opts + .runtime + .get_dns_resolver(rt::ResolverOptions { server_addr: None }) + .unwrap(); + + let mut resolver = super::DnsResolver::new(dns_client, opts, dns_opts); let (update_tx, mut update_rx) = mpsc::unbounded_channel(); let mut channel_controller = FakeChannelController { diff --git a/grpc/src/client/name_resolution/mod.rs b/grpc/src/client/name_resolution/mod.rs index 5c808250c..282232f1a 100644 --- a/grpc/src/client/name_resolution/mod.rs +++ b/grpc/src/client/name_resolution/mod.rs @@ -121,9 +121,13 @@ pub trait ResolverBuilder: Send + Sync { fn scheme(&self) -> &'static str; /// Returns the default authority for a channel using this name resolver - /// and target. This is typically the same as the service's name. By - /// default, the default_authority method automatically returns the path - /// portion of the target URI, with the leading prefix removed. + /// and target. This refers to the *dataplane authority* — the value used + /// in the `:authority` header of HTTP/2 requests — and is not to be confused + /// with the authority portion of the target URI, which typically specifies + /// the name of an external server used for name resolution. + /// + /// By default, this method returns the path portion of the target URI, + /// with the leading prefix removed. fn default_authority(&self, target: &Target) -> String { let path = target.path(); path.strip_prefix("/").unwrap_or(path).to_string() @@ -138,10 +142,14 @@ pub trait ResolverBuilder: Send + Sync { /// name resolver. #[non_exhaustive] pub struct ResolverOptions { - /// The authority that will be used for the channel by default.
This - /// contains either the result of the default_authority method of this - /// ResolverBuilder, or another string if the channel was configured to - /// override the default. + /// The authority that will be used for the channel by default. This refers + /// to the `:authority` value sent in HTTP/2 requests — the dataplane + /// authority — and not the authority portion of the target URI, which is + /// typically used to identify the name resolution server. + /// + /// This value is either the result of the `default_authority` method of + /// this `ResolverBuilder`, or another string if the channel was explicitly + /// configured to override the default. pub authority: String, /// The runtime which provides utilities to do async work. @@ -163,19 +171,19 @@ pub trait WorkScheduler: Send + Sync { /// Resolver watches for the updates on the specified target. /// Updates include address updates and service config updates. pub trait Resolver: Send { - /// Asks the resolver to obtain an updated resolver result, if - /// applicable. + /// Asks the resolver to obtain an updated resolver result, if applicable. /// /// This is useful for polling resolvers to decide when to re-resolve. - /// However, the implementation is not required to - /// re-resolve immediately upon receiving this call; it may instead - /// elect to delay based on some configured minimum time between - /// queries, to avoid hammering the name service with queries. + /// However, the implementation is not required to re-resolve immediately + /// upon receiving this call; it may instead elect to delay based on some + /// configured minimum time between queries, to avoid hammering the name + /// service with queries. /// /// For watch based resolvers, this may be a no-op. fn resolve_now(&mut self); - /// Called serially by the channel to to allow access to ChannelController. + /// Called serially by the channel to provide access to the + /// `ChannelController`. 
fn work(&mut self, channel_controller: &mut dyn ChannelController); } diff --git a/grpc/src/client/name_resolution/registry.rs b/grpc/src/client/name_resolution/registry.rs index fa4a42163..bb76bd52c 100644 --- a/grpc/src/client/name_resolution/registry.rs +++ b/grpc/src/client/name_resolution/registry.rs @@ -59,15 +59,11 @@ impl ResolverRegistry { /// The provided scheme is case-insensitive; any uppercase characters /// will be converted to lowercase before lookup. pub fn get(&self, scheme: &str) -> Option> { - self.m - .lock() - .unwrap() - .get(&scheme.to_lowercase()) - .map(|b| b.clone()) + self.m.lock().unwrap().get(&scheme.to_lowercase()).cloned() } } /// Global registry for resolver builders. pub fn global_registry() -> &'static ResolverRegistry { - GLOBAL_RESOLVER_REGISTRY.get_or_init(|| ResolverRegistry::new()) + GLOBAL_RESOLVER_REGISTRY.get_or_init(ResolverRegistry::new) }