diff --git a/Cargo.lock b/Cargo.lock index 76f85c88b141..ebfc412bdc67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1259,16 +1259,6 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" -[[package]] -name = "bcder" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627747a6774aab38beb35990d88309481378558875a41da1a4b2e373c906ef0" -dependencies = [ - "bytes", - "smallvec", -] - [[package]] name = "bigdecimal" version = "0.3.1" @@ -1648,6 +1638,7 @@ dependencies = [ "partition", "paste", "prometheus", + "rand 0.9.0", "rustc-hash 2.0.0", "serde_json", "session", @@ -3874,10 +3865,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" dependencies = [ "const-oid", + "der_derive", + "flagset", "pem-rfc7468", "zeroize", ] +[[package]] +name = "der_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8034092389675178f570469e6c3b0465d3d30b4505c294a6550db47f3c17ad18" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "deranged" version = "0.3.11" @@ -7055,6 +7059,12 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" +[[package]] +name = "md5" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" + [[package]] name = "measure_time" version = "0.8.3" @@ -8101,7 +8111,7 @@ dependencies = [ "common-test-util", "futures", "lazy_static", - "md5", + "md5 0.7.0", "moka", "opendal", "prometheus", @@ -8923,8 +8933,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.30.2" -source = "git+https://github.com/sunng87/pgwire?rev=127573d997228cfb70c7699881c568eae8131270#127573d997228cfb70c7699881c568eae8131270" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "449fecabd6a04033ec9c12e6c0bb7e663e03c3731f59d1e196c1ae9f1b65a9a9" dependencies = [ "async-trait", "bytes", @@ -8933,7 +8944,7 @@ dependencies = [ "futures", "hex", "lazy-regex", - "md5", + "md5 0.8.0", "postgres-types", "rand 0.9.0", "ring", @@ -12904,6 +12915,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tls_codec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" +dependencies = [ + "tls_codec_derive", + "zeroize", +] + +[[package]] +name = "tls_codec_derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "tokio" version = "1.44.2" @@ -12997,16 +13029,17 @@ dependencies = [ [[package]] name = "tokio-postgres-rustls" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab" +checksum = "27d684bad428a0f2481f42241f821db42c54e2dc81d8c00db8536c506b0a0144" dependencies = [ + "const-oid", "ring", "rustls", "tokio", "tokio-postgres", "tokio-rustls", - "x509-certificate", + "x509-cert", ] [[package]] @@ -14556,22 +14589,15 @@ dependencies = [ ] [[package]] -name = "x509-certificate" -version = "0.23.1" +name = "x509-cert" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66534846dec7a11d7c50a74b7cdb208b9a581cad890b7866430d438455847c85" +checksum = "1301e935010a701ae5f8655edc0ad17c44bad3ac5ce8c39185f75453b720ae94" dependencies = [ - "bcder", - "bytes", - "chrono", + "const-oid", "der", - "hex", - "pem", - "ring", - "signature", "spki", - "thiserror 1.0.64", - "zeroize", + "tls_codec", ] [[package]] diff --git a/src/catalog/Cargo.toml b/src/catalog/Cargo.toml index c7e2782c0e26..7b995ff06b0a 100644 --- a/src/catalog/Cargo.toml +++ b/src/catalog/Cargo.toml @@ -43,6 +43,7 @@ moka = { workspace = true, features = ["future", "sync"] } partition.workspace = true paste.workspace = true prometheus.workspace = true +rand.workspace = true rustc-hash.workspace = true serde_json.workspace = true session.workspace = true diff --git a/src/catalog/src/process_manager.rs b/src/catalog/src/process_manager.rs index ff2db26f46bd..bd1bdabb7a25 100644 --- a/src/catalog/src/process_manager.rs +++ b/src/catalog/src/process_manager.rs @@ -210,6 +210,10 @@ impl ProcessManager { Ok(false) } } + + pub fn server_addr(&self) -> &str { + &self.server_addr + } } pub struct Ticket { diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index bc64219e1157..3a891299ed7f 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -258,7 +258,7 @@ where opts.keep_alive.as_secs(), common_runtime::global_runtime(), user_provider.clone(), - Some(self.instance.process_manager().clone()), + self.instance.process_manager().clone(), )) as Box; handlers.insert((pg_server, pg_addr)); diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index bc64e19485cc..09c0c3384c4f 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -88,10 +88,7 @@ opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", rev = "a1fb4d opentelemetry-proto.workspace = true otel-arrow-rust.workspace = true parking_lot.workspace = true -#pgwire = { version = "0.30", default-features = false, features = ["server-api-ring"] } -pgwire = { git = "https://github.com/sunng87/pgwire", rev = "127573d997228cfb70c7699881c568eae8131270", default-features = false, features = [ - "server-api-ring", -] } +pgwire = { version = "0.31", default-features = false, features = ["server-api-ring"] } pin-project = "1.0" pipeline.workspace = true postgres-types = { version = "0.2", features = ["with-chrono-0_4", "with-serde_json-1"] } @@ -149,7 +146,7 @@ session = { workspace = true, features = ["testing"] } table.workspace = true tempfile = "3.0.0" tokio-postgres = "0.7" -tokio-postgres-rustls = "0.12" +tokio-postgres-rustls = "0.13" [target.'cfg(unix)'.dev-dependencies] pprof = { version = "0.14", features = ["criterion", "flamegraph"] } diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index 9ae323478518..dca6baacb262 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -29,11 +29,13 @@ use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use ::auth::UserProviderRef; +use auth::UserProviderRef; +use catalog::process_manager::ProcessManagerRef; use derive_builder::Builder; -use pgwire::api::auth::ServerParameterProvider; -use pgwire::api::copy::NoopCopyHandler; -use pgwire::api::{ClientInfo, PgWireServerHandlers}; +use pgwire::api::auth::{ServerParameterProvider, StartupHandler}; +use pgwire::api::cancel::CancelHandler; +use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; +use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers}; pub use server::PostgresServer; use session::context::Channel; use session::Session; @@ -71,6 +73,7 @@ impl ServerParameterProvider for GreptimeDBStartupParameters { pub struct PostgresServerHandlerInner { query_handler: ServerSqlQueryHandlerRef, + process_manager: ProcessManagerRef, login_verifier: PgLoginVerifier, force_tls: bool, param_provider: Arc, @@ -82,6 +85,7 @@ pub struct PostgresServerHandlerInner { #[derive(Builder)] pub(crate) struct MakePostgresServerHandler { query_handler: ServerSqlQueryHandlerRef, + process_manager: ProcessManagerRef, user_provider: Option, #[builder(default = "Arc::new(GreptimeDBStartupParameters::new())")] param_provider: Arc, @@ -91,43 +95,40 @@ pub(crate) struct MakePostgresServerHandler { pub(crate) struct PostgresServerHandler(Arc); impl PgWireServerHandlers for PostgresServerHandler { - type StartupHandler = PostgresServerHandlerInner; - type SimpleQueryHandler = PostgresServerHandlerInner; - type ExtendedQueryHandler = PostgresServerHandlerInner; - type CopyHandler = NoopCopyHandler; - type ErrorHandler = PostgresServerHandlerInner; - - fn simple_query_handler(&self) -> Arc { + fn simple_query_handler(&self) -> Arc { self.0.clone() } - fn extended_query_handler(&self) -> Arc { + fn extended_query_handler(&self) -> Arc { self.0.clone() } - fn startup_handler(&self) -> Arc { + fn startup_handler(&self) -> Arc { self.0.clone() } - fn copy_handler(&self) -> Arc { - Arc::new(NoopCopyHandler) + fn error_handler(&self) -> Arc { + self.0.clone() } - fn error_handler(&self) -> Arc { + fn cancel_handler(&self) -> Arc { self.0.clone() } } impl MakePostgresServerHandler { - fn make(&self, addr: Option, process_id: u32) -> PostgresServerHandler { - let session = Arc::new(Session::new( - addr, - Channel::Postgres, - Default::default(), - process_id, - )); + fn make(&self, addr: Option) -> PostgresServerHandler { + let process_id = self.process_manager.next_id(); + let secret_key = rand::random(); + + let session = Arc::new( + Session::new(addr, Channel::Postgres, Default::default(), process_id) + .with_secret_key(secret_key), + ); + let handler = PostgresServerHandlerInner { query_handler: self.query_handler.clone(), + process_manager: self.process_manager.clone(), login_verifier: PgLoginVerifier::new(self.user_provider.clone()), force_tls: self.force_tls, param_provider: self.param_provider.clone(), diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 9505c119565d..54eed613b1da 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -17,6 +17,7 @@ use std::sync::Exclusive; use ::auth::{userinfo_by_name, Identity, Password, UserInfoRef, UserProviderRef}; use async_trait::async_trait; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use futures::{Sink, SinkExt}; @@ -24,9 +25,8 @@ use pgwire::api::auth::StartupHandler; use pgwire::api::{auth, ClientInfo, PgWireConnectionState}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::response::ErrorResponse; -use pgwire::messages::startup::Authentication; -use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; -use session::Session; +use pgwire::messages::startup::{Authentication, SecretKey}; +use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage, ProtocolVersion}; use snafu::IntoError; use crate::error::{AuthSnafu, Result}; @@ -113,21 +113,98 @@ impl PgLoginVerifier { } } -fn set_client_info(client: &mut C, session: &Session) -where - C: ClientInfo, -{ - if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) { - session.set_catalog(current_catalog.clone()); +fn do_encode_pg_secret_key_bytes(secret_key: i32, server_addr: &str, catalog: &str) -> Vec { + let mut bytes = BytesMut::with_capacity(256); + + bytes.put_i32(secret_key); + + bytes.put_u8(server_addr.len() as u8); + bytes.put_u8(catalog.len() as u8); + + bytes.put_slice(server_addr.as_bytes()); + bytes.put_slice(catalog.as_bytes()); + + bytes.freeze().to_vec() +} + +fn do_decode_pg_secret_key_bytes(mut buf: Bytes) -> Option<(i32, String, String)> { + // this byte block should be at least 6-byte len + if buf.remaining() > 6 { + // get the i32 key + let key = buf.get_i32(); + // get server addr len + let server_addr_len = buf.get_u8() as usize; + // get catalog len + let catalog_len = buf.get_u8() as usize; + + if buf.remaining() >= server_addr_len + catalog_len { + let server_addr = String::from_utf8_lossy(&buf.split_to(server_addr_len)).into(); + let catalog = String::from_utf8_lossy(&buf.split_to(catalog_len)).into(); + + Some((key, server_addr, catalog)) + } else { + None + } + } else { + None + } +} + +impl PostgresServerHandlerInner { + /// Generate a Postgres specific secret key + /// + /// The secret key has to carry enough information to call `kill_process` in + /// a distributed setup. It has to carry: + /// + /// - A random i32 number + /// - the local frontend server address: u8(byte length) + bytes + /// - the catalog that client has authenticated: u8(byte length) + bytes + /// + /// The final byte content is: + /// + /// int32(secret_key) + u8(server_addr len) + u8(catalog len) + /// + server_addr... + catalog... + /// + /// According to Postgres spec, the key should carry less than 256 bytes + pub fn encode_secret_key_bytes(&self) -> Vec { + do_encode_pg_secret_key_bytes( + self.session.secret_key().unwrap_or(0), + self.process_manager.server_addr(), + &self.session.catalog(), + ) } - if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) { - session.set_schema(current_schema.clone()); + + /// Validate and decode secret key into (catalog, frontend) tuple + pub fn decode_secret_key(&self, secret_key_bytes: Bytes) -> Option<(i32, String, String)> { + do_decode_pg_secret_key_bytes(secret_key_bytes) } - // pass generated process id and secret key to client, this information will - // be sent to postgres client for query cancellation. - client.set_pid_and_secret_key(session.process_id() as i32, rand::random::()); - // set userinfo outside + fn set_client_info(&self, client: &mut C) + where + C: ClientInfo, + { + if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) { + self.session.set_catalog(current_catalog.clone()); + } + if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) { + self.session.set_schema(current_schema.clone()); + } + + // pass generated process id and secret key to client, this information will + // be sent to postgres client for query cancellation. + if client.protocol_version() == ProtocolVersion::PROTOCOL3_0 { + // 3.0 protocol is not supported for cancel, we give client all 0 + client.set_pid_and_secret_key(0, SecretKey::I32(0)); + } else { + let secret_key_bytes = self.encode_secret_key_bytes(); + client.set_pid_and_secret_key( + self.session.process_id() as i32, + SecretKey::Bytes(Bytes::copy_from_slice(&secret_key_bytes)), + ); + } + + // set userinfo outside + } } #[async_trait] @@ -154,6 +231,8 @@ impl StartupHandler for PostgresServerHandlerInner { return Ok(()); } + // performance postgres protocol negotiation + auth::protocol_negotiation(client, startup).await?; auth::save_startup_parameters_to_metadata(client, startup); // check if db is valid @@ -180,7 +259,7 @@ impl StartupHandler for PostgresServerHandlerInner { self.session.set_user_info(userinfo_by_name( client.metadata().get(super::METADATA_USER).cloned(), )); - set_client_info(client, &self.session); + self.set_client_info(client); auth::finish_authentication(client, self.param_provider.as_ref()).await?; } } @@ -197,7 +276,7 @@ impl StartupHandler for PostgresServerHandlerInner { if let Ok(Some(user_info)) = auth_result { self.session.set_user_info(user_info); - set_client_info(client, &self.session); + self.set_client_info(client); auth::finish_authentication(client, self.param_provider.as_ref()).await?; } else { return send_error( @@ -257,3 +336,20 @@ where Ok(DbResolution::NotFound("Database not specified".to_owned())) } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_secret_key_roundtrip() { + let tuple = (3244, "10.0.0.23", "greptime"); + let bytes = do_encode_pg_secret_key_bytes(tuple.0, tuple.1, tuple.2); + let decoded = do_decode_pg_secret_key_bytes(Bytes::copy_from_slice(&bytes)) + .expect("failed to decode secret key"); + + assert_eq!(tuple.0, decoded.0); + assert_eq!(tuple.1, decoded.1); + assert_eq!(tuple.2, decoded.2); + } +} diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 97c48a8ac98c..4079dc3c71ac 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -25,6 +25,7 @@ use datafusion_common::ParamValues; use datatypes::prelude::ConcreteDataType; use datatypes::schema::SchemaRef; use futures::{future, stream, Sink, SinkExt, Stream, StreamExt}; +use pgwire::api::cancel::CancelHandler; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ @@ -33,6 +34,8 @@ use pgwire::api::results::{ use pgwire::api::stmt::{QueryParser, StoredStatement}; use pgwire::api::{ClientInfo, ErrorHandler, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::messages::cancel::CancelRequest; +use pgwire::messages::startup::SecretKey; use pgwire::messages::PgWireBackendMessage; use query::query_engine::DescribeResult; use session::context::QueryContextRef; @@ -424,3 +427,22 @@ impl ErrorHandler for PostgresServerHandlerInner { debug!("Postgres interface error {}", error) } } + +#[async_trait] +impl CancelHandler for PostgresServerHandlerInner { + async fn on_cancel_request(&self, cancel_request: CancelRequest) { + let pid = cancel_request.pid as u32; + + // We don't support i32 secret key even if it seems workable on + // standalone setup. + if let SecretKey::Bytes(secret_key) = cancel_request.secret_key { + if let Some((_key, server_addr, catalog)) = self.decode_secret_key(secret_key) { + //TODO(sunng87): verify _key + let _ = self + .process_manager + .kill_process(server_addr, catalog, pid) + .await; + } + } + } +} diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index a509771fcf6d..c8aaa680aeca 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -38,7 +38,6 @@ pub struct PostgresServer { tls_server_config: Arc, keep_alive_secs: u64, bind_addr: Option, - process_manager: Option, } impl PostgresServer { @@ -50,11 +49,12 @@ impl PostgresServer { keep_alive_secs: u64, io_runtime: Runtime, user_provider: Option, - process_manager: Option, + process_manager: ProcessManagerRef, ) -> PostgresServer { let make_handler = Arc::new( MakePostgresServerHandlerBuilder::default() .query_handler(query_handler.clone()) + .process_manager(process_manager.clone()) .user_provider(user_provider.clone()) .force_tls(force_tls) .build() @@ -66,7 +66,6 @@ impl PostgresServer { tls_server_config, keep_alive_secs, bind_addr: None, - process_manager, } } @@ -77,12 +76,10 @@ impl PostgresServer { ) -> impl Future { let handler_maker = self.make_handler.clone(); let tls_server_config = self.tls_server_config.clone(); - let process_manager = self.process_manager.clone(); accepting_stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from); let handler_maker = handler_maker.clone(); - let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0); async move { match tcp_stream { @@ -101,7 +98,7 @@ impl PostgresServer { let _handle = io_runtime.spawn(async move { crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc(); - let pg_handler = Arc::new(handler_maker.make(addr, process_id)); + let pg_handler = Arc::new(handler_maker.make(addr)); let r = process_socket(io_stream, tls_acceptor.clone(), pg_handler).await; crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec(); diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 0daffdc32aab..0cc6446fa14a 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -18,6 +18,7 @@ use std::time::Duration; use auth::tests::{DatabaseAuthInfo, MockUserProvider}; use auth::UserProviderRef; +use catalog::process_manager::ProcessManager; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_runtime::runtime::BuilderBuild; use common_runtime::Builder as RuntimeBuilder; @@ -71,7 +72,7 @@ fn create_postgres_server( 0, io_runtime, user_provider, - None, + Arc::new(ProcessManager::new("127.0.0.1".to_string(), None)), ))) } diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 7688b0b6599f..a3f304a11371 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -43,6 +43,8 @@ pub struct Session { configuration_variables: Arc, // the process id to use when killing the query process_id: u32, + // a postgres specific key for cancel request + secret_key: Option, } pub type SessionRef = Arc; @@ -85,9 +87,15 @@ impl Session { configuration_variables: Arc::new(configuration_variables), mutable_inner: Arc::new(RwLock::new(MutableInner::default())), process_id, + secret_key: None, } } + pub fn with_secret_key(mut self, secret_key: i32) -> Self { + self.secret_key = Some(secret_key); + self + } + pub fn new_query_context(&self) -> QueryContextRef { QueryContextBuilder::default() // catalog is not allowed for update in query context so we use @@ -155,4 +163,8 @@ impl Session { pub fn process_id(&self) -> u32 { self.process_id } + + pub fn secret_key(&self) -> Option { + self.secret_key + } } diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 8cc093224087..32647036f7a2 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use auth::UserProviderRef; use axum::Router; use catalog::kvbackend::KvBackendCatalogManager; +use catalog::process_manager::ProcessManager; use common_base::secrets::ExposeSecret; use common_config::Configurable; use common_meta::key::catalog_name::CatalogNameKey; @@ -698,7 +699,7 @@ pub async fn setup_pg_server_with_user_provider( 0, runtime, user_provider, - None, + Arc::new(ProcessManager::new("127.0.0.1".to_string(), None)), )); pg_server