Skip to content

Commit c7d6273

Browse files
authored
Support for prepared statements (#474)
* Start prepared statements * parse * Ok * optional * dont rewrite anonymous prepared stmts * Dont rewrite anonymous prep statements * hm? * prep statements * I see! * comment * Print config value * Rewrite bind and add sqlx test * fmt * ok * Fix * Fix stats * its late * clean up PREPARE
1 parent 94c7818 commit c7d6273

File tree

14 files changed

+1954
-10
lines changed

14 files changed

+1954
-10
lines changed

dev/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM rust:bullseye
1+
FROM rust:1.70-bullseye
22

33
# Dependencies
44
RUN apt-get update -y \

pgcat.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ tcp_keepalives_count = 5
6060
# Number of seconds between keepalive packets.
6161
tcp_keepalives_interval = 5
6262

63+
# Handle prepared statements.
64+
prepared_statements = true
65+
6366
# Path to TLS Certificate file to use for TLS connections
6467
# tls_certificate = ".circleci/server.cert"
6568
# Path to TLS private key file to use for TLS connections

src/admin.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,8 @@ where
699699
("bytes_sent", DataType::Numeric),
700700
("bytes_received", DataType::Numeric),
701701
("age_seconds", DataType::Numeric),
702+
("prepare_cache_hit", DataType::Numeric),
703+
("prepare_cache_miss", DataType::Numeric),
702704
];
703705

704706
let new_map = get_server_stats();
@@ -722,6 +724,14 @@ where
722724
.duration_since(server.connect_time())
723725
.as_secs()
724726
.to_string(),
727+
server
728+
.prepared_hit_count
729+
.load(Ordering::Relaxed)
730+
.to_string(),
731+
server
732+
.prepared_miss_count
733+
.load(Ordering::Relaxed)
734+
.to_string(),
725735
];
726736

727737
res.put(data_row(&row));

src/client.rs

Lines changed: 209 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ use crate::pool::BanReason;
33
/// Handle clients by pretending to be a PostgreSQL server.
44
use bytes::{Buf, BufMut, BytesMut};
55
use log::{debug, error, info, trace, warn};
6+
use once_cell::sync::Lazy;
67
use std::collections::HashMap;
7-
use std::sync::Arc;
8+
use std::sync::{atomic::AtomicUsize, Arc};
89
use std::time::Instant;
910
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
1011
use tokio::net::TcpStream;
@@ -13,7 +14,9 @@ use tokio::sync::mpsc::Sender;
1314

1415
use crate::admin::{generate_server_info_for_admin, handle_admin};
1516
use crate::auth_passthrough::refetch_auth_hash;
16-
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
17+
use crate::config::{
18+
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
19+
};
1720
use crate::constants::*;
1821
use crate::messages::*;
1922
use crate::plugins::PluginOutput;
@@ -25,6 +28,11 @@ use crate::tls::Tls;
2528

2629
use tokio_rustls::server::TlsStream;
2730

31+
/// Incrementally count prepared statements
32+
/// to avoid random conflicts in places where the random number generator is weak.
33+
pub static PREPARED_STATEMENT_COUNTER: Lazy<Arc<AtomicUsize>> =
34+
Lazy::new(|| Arc::new(AtomicUsize::new(0)));
35+
2836
/// Type of connection received from client.
2937
enum ClientConnectionType {
3038
Startup,
@@ -93,6 +101,9 @@ pub struct Client<S, T> {
93101

94102
/// Used to notify clients about an impending shutdown
95103
shutdown: Receiver<()>,
104+
105+
/// Prepared statements
106+
prepared_statements: HashMap<String, Parse>,
96107
}
97108

98109
/// Client entrypoint.
@@ -682,6 +693,7 @@ where
682693
application_name: application_name.to_string(),
683694
shutdown,
684695
connected_to_server: false,
696+
prepared_statements: HashMap::new(),
685697
})
686698
}
687699

@@ -716,6 +728,7 @@ where
716728
application_name: String::from("undefined"),
717729
shutdown,
718730
connected_to_server: false,
731+
prepared_statements: HashMap::new(),
719732
})
720733
}
721734

@@ -757,6 +770,10 @@ where
757770
// Result returned by one of the plugins.
758771
let mut plugin_output = None;
759772

773+
// Prepared statement being executed
774+
let mut prepared_statement = None;
775+
let mut will_prepare = false;
776+
760777
// Our custom protocol loop.
761778
// We expect the client to either start a transaction with regular queries
762779
// or issue commands for our sharding and server selection protocol.
@@ -766,13 +783,16 @@ where
766783
self.transaction_mode
767784
);
768785

786+
// Should we rewrite prepared statements and bind messages?
787+
let mut prepared_statements_enabled = get_prepared_statements();
788+
769789
// Read a complete message from the client, which normally would be
770790
// either a `Q` (query) or `P` (prepare, extended protocol).
771791
// We can parse it here before grabbing a server from the pool,
772792
// in case the client is sending some custom protocol messages, e.g.
773793
// SET SHARDING KEY TO 'bigint';
774794

775-
let message = tokio::select! {
795+
let mut message = tokio::select! {
776796
_ = self.shutdown.recv() => {
777797
if !self.admin {
778798
error_response_terminal(
@@ -800,7 +820,21 @@ where
800820
// allocate a connection, we wouldn't be able to send back an error message
801821
// to the client so we buffer them and defer the decision to error out or not
802822
// to when we get the S message
803-
'D' | 'E' => {
823+
'D' => {
824+
if prepared_statements_enabled {
825+
let name;
826+
(name, message) = self.rewrite_describe(message).await?;
827+
828+
if let Some(name) = name {
829+
prepared_statement = Some(name);
830+
}
831+
}
832+
833+
self.buffer.put(&message[..]);
834+
continue;
835+
}
836+
837+
'E' => {
804838
self.buffer.put(&message[..]);
805839
continue;
806840
}
@@ -830,6 +864,11 @@ where
830864
}
831865

832866
'P' => {
867+
if prepared_statements_enabled {
868+
(prepared_statement, message) = self.rewrite_parse(message)?;
869+
will_prepare = true;
870+
}
871+
833872
self.buffer.put(&message[..]);
834873

835874
if query_router.query_parser_enabled() {
@@ -846,6 +885,10 @@ where
846885
}
847886

848887
'B' => {
888+
if prepared_statements_enabled {
889+
(prepared_statement, message) = self.rewrite_bind(message).await?;
890+
}
891+
849892
self.buffer.put(&message[..]);
850893

851894
if query_router.query_parser_enabled() {
@@ -1054,7 +1097,48 @@ where
10541097
// If the client is in session mode, no more custom protocol
10551098
// commands will be accepted.
10561099
loop {
1057-
let message = match initial_message {
1100+
// Only check if we should rewrite prepared statements
1101+
// in session mode. In transaction mode, we check at the beginning of
1102+
// each transaction.
1103+
if !self.transaction_mode {
1104+
prepared_statements_enabled = get_prepared_statements();
1105+
}
1106+
1107+
debug!("Prepared statement active: {:?}", prepared_statement);
1108+
1109+
// We are processing a prepared statement.
1110+
if let Some(ref name) = prepared_statement {
1111+
debug!("Checking prepared statement is on server");
1112+
// Get the prepared statement the server expects to see.
1113+
let statement = match self.prepared_statements.get(name) {
1114+
Some(statement) => {
1115+
debug!("Prepared statement `{}` found in cache", name);
1116+
statement
1117+
}
1118+
None => {
1119+
return Err(Error::ClientError(format!(
1120+
"prepared statement `{}` not found",
1121+
name
1122+
)))
1123+
}
1124+
};
1125+
1126+
// Since it's already in the buffer, we don't need to prepare it on this server.
1127+
if will_prepare {
1128+
server.will_prepare(&statement.name);
1129+
will_prepare = false;
1130+
} else {
1131+
// The statement is not prepared on the server, so we need to prepare it.
1132+
if server.should_prepare(&statement.name) {
1133+
server.prepare(statement).await?;
1134+
}
1135+
}
1136+
1137+
// Done processing the prepared statement.
1138+
prepared_statement = None;
1139+
}
1140+
1141+
let mut message = match initial_message {
10581142
None => {
10591143
trace!("Waiting for message inside transaction or in session mode");
10601144

@@ -1173,6 +1257,11 @@ where
11731257
// Parse
11741258
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
11751259
'P' => {
1260+
if prepared_statements_enabled {
1261+
(prepared_statement, message) = self.rewrite_parse(message)?;
1262+
will_prepare = true;
1263+
}
1264+
11761265
if query_router.query_parser_enabled() {
11771266
if let Ok(ast) = QueryRouter::parse(&message) {
11781267
if let Ok(output) = query_router.execute_plugins(&ast).await {
@@ -1187,12 +1276,25 @@ where
11871276
// Bind
11881277
// The placeholder's replacements are here, e.g. '[email protected]' and 'true'
11891278
'B' => {
1279+
if prepared_statements_enabled {
1280+
(prepared_statement, message) = self.rewrite_bind(message).await?;
1281+
}
1282+
11901283
self.buffer.put(&message[..]);
11911284
}
11921285

11931286
// Describe
11941287
// Command a client can issue to describe a previously prepared named statement.
11951288
'D' => {
1289+
if prepared_statements_enabled {
1290+
let name;
1291+
(name, message) = self.rewrite_describe(message).await?;
1292+
1293+
if let Some(name) = name {
1294+
prepared_statement = Some(name);
1295+
}
1296+
}
1297+
11961298
self.buffer.put(&message[..]);
11971299
}
11981300

@@ -1235,7 +1337,7 @@ where
12351337
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
12361338

12371339
// Almost certainly true
1238-
if first_message_code == 'P' {
1340+
if first_message_code == 'P' && !prepared_statements_enabled {
12391341
// Message layout
12401342
// P followed by 32 int followed by null-terminated statement name
12411343
// So message code should be in offset 0 of the buffer, first character
@@ -1363,6 +1465,107 @@ where
13631465
}
13641466
}
13651467

1468+
/// Rewrite Parse (F) message to set the prepared statement name to one we control.
1469+
/// Save it into the client cache.
1470+
fn rewrite_parse(&mut self, message: BytesMut) -> Result<(Option<String>, BytesMut), Error> {
1471+
let parse: Parse = (&message).try_into()?;
1472+
1473+
let name = parse.name.clone();
1474+
1475+
// Don't rewrite anonymous prepared statements
1476+
if parse.anonymous() {
1477+
debug!("Anonymous prepared statement");
1478+
return Ok((None, message));
1479+
}
1480+
1481+
let parse = parse.rename();
1482+
1483+
debug!(
1484+
"Renamed prepared statement `{}` to `{}` and saved to cache",
1485+
name, parse.name
1486+
);
1487+
1488+
self.prepared_statements.insert(name.clone(), parse.clone());
1489+
1490+
Ok((Some(name), parse.try_into()?))
1491+
}
1492+
1493+
/// Rewrite the Bind (F) message to use the prepared statement name
1494+
/// saved in the client cache.
1495+
async fn rewrite_bind(
1496+
&mut self,
1497+
message: BytesMut,
1498+
) -> Result<(Option<String>, BytesMut), Error> {
1499+
let bind: Bind = (&message).try_into()?;
1500+
let name = bind.prepared_statement.clone();
1501+
1502+
if bind.anonymous() {
1503+
debug!("Anonymous bind message");
1504+
return Ok((None, message));
1505+
}
1506+
1507+
match self.prepared_statements.get(&name) {
1508+
Some(prepared_stmt) => {
1509+
let bind = bind.reassign(prepared_stmt);
1510+
1511+
debug!("Rewrote bind `{}` to `{}`", name, bind.prepared_statement);
1512+
1513+
Ok((Some(name), bind.try_into()?))
1514+
}
1515+
None => {
1516+
debug!("Got bind for unknown prepared statement {:?}", bind);
1517+
1518+
error_response(
1519+
&mut self.write,
1520+
&format!(
1521+
"prepared statement \"{}\" does not exist",
1522+
bind.prepared_statement
1523+
),
1524+
)
1525+
.await?;
1526+
1527+
Err(Error::ClientError(format!(
1528+
"Prepared statement `{}` doesn't exist",
1529+
name
1530+
)))
1531+
}
1532+
}
1533+
}
1534+
1535+
/// Rewrite the Describe (F) message to use the prepared statement name
1536+
/// saved in the client cache.
1537+
async fn rewrite_describe(
1538+
&mut self,
1539+
message: BytesMut,
1540+
) -> Result<(Option<String>, BytesMut), Error> {
1541+
let describe: Describe = (&message).try_into()?;
1542+
let name = describe.statement_name.clone();
1543+
1544+
if describe.anonymous() {
1545+
debug!("Anonymous describe");
1546+
return Ok((None, message));
1547+
}
1548+
1549+
match self.prepared_statements.get(&name) {
1550+
Some(prepared_stmt) => {
1551+
let describe = describe.rename(&prepared_stmt.name);
1552+
1553+
debug!(
1554+
"Rewrote describe `{}` to `{}`",
1555+
name, describe.statement_name
1556+
);
1557+
1558+
Ok((Some(name), describe.try_into()?))
1559+
}
1560+
1561+
None => {
1562+
debug!("Got describe for unknown prepared statement {:?}", describe);
1563+
1564+
Ok((None, message))
1565+
}
1566+
}
1567+
}
1568+
13661569
/// Release the server from the client: it can't cancel its queries anymore.
13671570
pub fn release(&self) {
13681571
let mut guard = self.client_server_map.lock();

0 commit comments

Comments
 (0)