Skip to content

Commit 6287844

Browse files
committed
feat: implement cancel handler for postgres
1 parent cbd5240 commit 6287844

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

src/servers/src/postgres.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ use std::net::SocketAddr;
3030
use std::sync::Arc;
3131

3232
use ::auth::UserProviderRef;
33+
use catalog::process_manager::ProcessManagerRef;
3334
use derive_builder::Builder;
3435
use pgwire::api::auth::ServerParameterProvider;
3536
use pgwire::api::auth::StartupHandler;
37+
use pgwire::api::cancel::CancelHandler;
3638
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
3739
use pgwire::api::ErrorHandler;
3840
use pgwire::api::{ClientInfo, PgWireServerHandlers};
@@ -73,6 +75,7 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
7375

7476
pub struct PostgresServerHandlerInner {
7577
query_handler: ServerSqlQueryHandlerRef,
78+
process_manager: ProcessManagerRef,
7679
login_verifier: PgLoginVerifier,
7780
force_tls: bool,
7881
param_provider: Arc<GreptimeDBStartupParameters>,
@@ -84,6 +87,7 @@ pub struct PostgresServerHandlerInner {
8487
#[derive(Builder)]
8588
pub(crate) struct MakePostgresServerHandler {
8689
query_handler: ServerSqlQueryHandlerRef,
90+
process_manager: ProcessManagerRef,
8791
user_provider: Option<UserProviderRef>,
8892
#[builder(default = "Arc::new(GreptimeDBStartupParameters::new())")]
8993
param_provider: Arc<GreptimeDBStartupParameters>,
@@ -108,10 +112,15 @@ impl PgWireServerHandlers for PostgresServerHandler {
108112
fn error_handler(&self) -> Arc<impl ErrorHandler> {
109113
self.0.clone()
110114
}
115+
116+
fn cancel_handler(&self) -> Arc<impl CancelHandler> {
117+
self.0.clone()
118+
}
111119
}
112120

113121
impl MakePostgresServerHandler {
114-
fn make(&self, addr: Option<SocketAddr>, process_id: u32) -> PostgresServerHandler {
122+
fn make(&self, addr: Option<SocketAddr>) -> PostgresServerHandler {
123+
let process_id = self.process_manager.next_id();
115124
let session = Arc::new(Session::new(
116125
addr,
117126
Channel::Postgres,
@@ -120,6 +129,7 @@ impl MakePostgresServerHandler {
120129
));
121130
let handler = PostgresServerHandlerInner {
122131
query_handler: self.query_handler.clone(),
132+
process_manager: self.process_manager.clone(),
123133
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
124134
force_tls: self.force_tls,
125135
param_provider: self.param_provider.clone(),

src/servers/src/postgres/handler.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use datafusion_common::ParamValues;
2525
use datatypes::prelude::ConcreteDataType;
2626
use datatypes::schema::SchemaRef;
2727
use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
28+
use pgwire::api::cancel::CancelHandler;
2829
use pgwire::api::portal::{Format, Portal};
2930
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
3031
use pgwire::api::results::{
@@ -33,6 +34,7 @@ use pgwire::api::results::{
3334
use pgwire::api::stmt::{QueryParser, StoredStatement};
3435
use pgwire::api::{ClientInfo, ErrorHandler, Type};
3536
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
37+
use pgwire::messages::cancel::CancelRequest;
3638
use pgwire::messages::PgWireBackendMessage;
3739
use query::query_engine::DescribeResult;
3840
use session::context::QueryContextRef;
@@ -424,3 +426,17 @@ impl ErrorHandler for PostgresServerHandlerInner {
424426
debug!("Postgres interface error {}", error)
425427
}
426428
}
429+
430+
#[async_trait]
431+
impl CancelHandler for PostgresServerHandlerInner {
432+
async fn on_cancel_request(&self, cancel_request: CancelRequest) {
433+
let pid = cancel_request.pid as u32;
434+
let _secret_key = cancel_request.secret_key;
435+
436+
// FIXME:
437+
let _ = self
438+
.process_manager
439+
.kill_process("todo".to_string(), "todo".to_string(), pid)
440+
.await;
441+
}
442+
}

src/servers/src/postgres/server.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ pub struct PostgresServer {
3838
tls_server_config: Arc<ReloadableTlsServerConfig>,
3939
keep_alive_secs: u64,
4040
bind_addr: Option<SocketAddr>,
41-
process_manager: Option<ProcessManagerRef>,
4241
}
4342

4443
impl PostgresServer {
@@ -50,11 +49,12 @@ impl PostgresServer {
5049
keep_alive_secs: u64,
5150
io_runtime: Runtime,
5251
user_provider: Option<UserProviderRef>,
53-
process_manager: Option<ProcessManagerRef>,
52+
process_manager: ProcessManagerRef,
5453
) -> PostgresServer {
5554
let make_handler = Arc::new(
5655
MakePostgresServerHandlerBuilder::default()
5756
.query_handler(query_handler.clone())
57+
.process_manager(process_manager.clone())
5858
.user_provider(user_provider.clone())
5959
.force_tls(force_tls)
6060
.build()
@@ -66,7 +66,6 @@ impl PostgresServer {
6666
tls_server_config,
6767
keep_alive_secs,
6868
bind_addr: None,
69-
process_manager,
7069
}
7170
}
7271

@@ -77,12 +76,10 @@ impl PostgresServer {
7776
) -> impl Future<Output = ()> {
7877
let handler_maker = self.make_handler.clone();
7978
let tls_server_config = self.tls_server_config.clone();
80-
let process_manager = self.process_manager.clone();
8179
accepting_stream.for_each(move |tcp_stream| {
8280
let io_runtime = io_runtime.clone();
8381
let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
8482
let handler_maker = handler_maker.clone();
85-
let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);
8683

8784
async move {
8885
match tcp_stream {
@@ -101,7 +98,7 @@ impl PostgresServer {
10198

10299
let _handle = io_runtime.spawn(async move {
103100
crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
104-
let pg_handler = Arc::new(handler_maker.make(addr, process_id));
101+
let pg_handler = Arc::new(handler_maker.make(addr));
105102
let r =
106103
process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
107104
crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();

0 commit comments

Comments
 (0)