@@ -3,8 +3,9 @@ use crate::pool::BanReason;
3
3
/// Handle clients by pretending to be a PostgreSQL server.
4
4
use bytes:: { Buf , BufMut , BytesMut } ;
5
5
use log:: { debug, error, info, trace, warn} ;
6
+ use once_cell:: sync:: Lazy ;
6
7
use std:: collections:: HashMap ;
7
- use std:: sync:: Arc ;
8
+ use std:: sync:: { atomic :: AtomicUsize , Arc } ;
8
9
use std:: time:: Instant ;
9
10
use tokio:: io:: { split, AsyncReadExt , BufReader , ReadHalf , WriteHalf } ;
10
11
use tokio:: net:: TcpStream ;
@@ -13,7 +14,9 @@ use tokio::sync::mpsc::Sender;
13
14
14
15
use crate :: admin:: { generate_server_info_for_admin, handle_admin} ;
15
16
use crate :: auth_passthrough:: refetch_auth_hash;
16
- use crate :: config:: { get_config, get_idle_client_in_transaction_timeout, Address , PoolMode } ;
17
+ use crate :: config:: {
18
+ get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address , PoolMode ,
19
+ } ;
17
20
use crate :: constants:: * ;
18
21
use crate :: messages:: * ;
19
22
use crate :: plugins:: PluginOutput ;
@@ -25,6 +28,11 @@ use crate::tls::Tls;
25
28
26
29
use tokio_rustls:: server:: TlsStream ;
27
30
31
+ /// Incrementally count prepared statements
32
+ /// to avoid random conflicts in places where the random number generator is weak.
33
+ pub static PREPARED_STATEMENT_COUNTER : Lazy < Arc < AtomicUsize > > =
34
+ Lazy :: new ( || Arc :: new ( AtomicUsize :: new ( 0 ) ) ) ;
35
+
28
36
/// Type of connection received from client.
29
37
enum ClientConnectionType {
30
38
Startup ,
@@ -93,6 +101,9 @@ pub struct Client<S, T> {
93
101
94
102
/// Used to notify clients about an impending shutdown
95
103
shutdown : Receiver < ( ) > ,
104
+
105
+ /// Prepared statements
106
+ prepared_statements : HashMap < String , Parse > ,
96
107
}
97
108
98
109
/// Client entrypoint.
@@ -682,6 +693,7 @@ where
682
693
application_name : application_name. to_string ( ) ,
683
694
shutdown,
684
695
connected_to_server : false ,
696
+ prepared_statements : HashMap :: new ( ) ,
685
697
} )
686
698
}
687
699
@@ -716,6 +728,7 @@ where
716
728
application_name : String :: from ( "undefined" ) ,
717
729
shutdown,
718
730
connected_to_server : false ,
731
+ prepared_statements : HashMap :: new ( ) ,
719
732
} )
720
733
}
721
734
@@ -757,6 +770,10 @@ where
757
770
// Result returned by one of the plugins.
758
771
let mut plugin_output = None ;
759
772
773
+ // Prepared statement being executed
774
+ let mut prepared_statement = None ;
775
+ let mut will_prepare = false ;
776
+
760
777
// Our custom protocol loop.
761
778
// We expect the client to either start a transaction with regular queries
762
779
// or issue commands for our sharding and server selection protocol.
@@ -766,13 +783,16 @@ where
766
783
self . transaction_mode
767
784
) ;
768
785
786
+ // Should we rewrite prepared statements and bind messages?
787
+ let mut prepared_statements_enabled = get_prepared_statements ( ) ;
788
+
769
789
// Read a complete message from the client, which normally would be
770
790
// either a `Q` (query) or `P` (prepare, extended protocol).
771
791
// We can parse it here before grabbing a server from the pool,
772
792
// in case the client is sending some custom protocol messages, e.g.
773
793
// SET SHARDING KEY TO 'bigint';
774
794
775
- let message = tokio:: select! {
795
+ let mut message = tokio:: select! {
776
796
_ = self . shutdown. recv( ) => {
777
797
if !self . admin {
778
798
error_response_terminal(
@@ -800,7 +820,21 @@ where
800
820
// allocate a connection, we wouldn't be able to send back an error message
801
821
// to the client so we buffer them and defer the decision to error out or not
802
822
// to when we get the S message
803
- 'D' | 'E' => {
823
+ 'D' => {
824
+ if prepared_statements_enabled {
825
+ let name;
826
+ ( name, message) = self . rewrite_describe ( message) . await ?;
827
+
828
+ if let Some ( name) = name {
829
+ prepared_statement = Some ( name) ;
830
+ }
831
+ }
832
+
833
+ self . buffer . put ( & message[ ..] ) ;
834
+ continue ;
835
+ }
836
+
837
+ 'E' => {
804
838
self . buffer . put ( & message[ ..] ) ;
805
839
continue ;
806
840
}
@@ -830,6 +864,11 @@ where
830
864
}
831
865
832
866
'P' => {
867
+ if prepared_statements_enabled {
868
+ ( prepared_statement, message) = self . rewrite_parse ( message) ?;
869
+ will_prepare = true ;
870
+ }
871
+
833
872
self . buffer . put ( & message[ ..] ) ;
834
873
835
874
if query_router. query_parser_enabled ( ) {
@@ -846,6 +885,10 @@ where
846
885
}
847
886
848
887
'B' => {
888
+ if prepared_statements_enabled {
889
+ ( prepared_statement, message) = self . rewrite_bind ( message) . await ?;
890
+ }
891
+
849
892
self . buffer . put ( & message[ ..] ) ;
850
893
851
894
if query_router. query_parser_enabled ( ) {
@@ -1054,7 +1097,48 @@ where
1054
1097
// If the client is in session mode, no more custom protocol
1055
1098
// commands will be accepted.
1056
1099
loop {
1057
- let message = match initial_message {
1100
+ // Only check if we should rewrite prepared statements
1101
+ // in session mode. In transaction mode, we check at the beginning of
1102
+ // each transaction.
1103
+ if !self . transaction_mode {
1104
+ prepared_statements_enabled = get_prepared_statements ( ) ;
1105
+ }
1106
+
1107
+ debug ! ( "Prepared statement active: {:?}" , prepared_statement) ;
1108
+
1109
+ // We are processing a prepared statement.
1110
+ if let Some ( ref name) = prepared_statement {
1111
+ debug ! ( "Checking prepared statement is on server" ) ;
1112
+ // Get the prepared statement the server expects to see.
1113
+ let statement = match self . prepared_statements . get ( name) {
1114
+ Some ( statement) => {
1115
+ debug ! ( "Prepared statement `{}` found in cache" , name) ;
1116
+ statement
1117
+ }
1118
+ None => {
1119
+ return Err ( Error :: ClientError ( format ! (
1120
+ "prepared statement `{}` not found" ,
1121
+ name
1122
+ ) ) )
1123
+ }
1124
+ } ;
1125
+
1126
+ // Since it's already in the buffer, we don't need to prepare it on this server.
1127
+ if will_prepare {
1128
+ server. will_prepare ( & statement. name ) ;
1129
+ will_prepare = false ;
1130
+ } else {
1131
+ // The statement is not prepared on the server, so we need to prepare it.
1132
+ if server. should_prepare ( & statement. name ) {
1133
+ server. prepare ( statement) . await ?;
1134
+ }
1135
+ }
1136
+
1137
+ // Done processing the prepared statement.
1138
+ prepared_statement = None ;
1139
+ }
1140
+
1141
+ let mut message = match initial_message {
1058
1142
None => {
1059
1143
trace ! ( "Waiting for message inside transaction or in session mode" ) ;
1060
1144
@@ -1173,6 +1257,11 @@ where
1173
1257
// Parse
1174
1258
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
1175
1259
'P' => {
1260
+ if prepared_statements_enabled {
1261
+ ( prepared_statement, message) = self . rewrite_parse ( message) ?;
1262
+ will_prepare = true ;
1263
+ }
1264
+
1176
1265
if query_router. query_parser_enabled ( ) {
1177
1266
if let Ok ( ast) = QueryRouter :: parse ( & message) {
1178
1267
if let Ok ( output) = query_router. execute_plugins ( & ast) . await {
@@ -1187,12 +1276,25 @@ where
1187
1276
// Bind
1188
1277
// The placeholder's replacements are here, e.g. '[email protected] ' and 'true'
1189
1278
'B' => {
1279
+ if prepared_statements_enabled {
1280
+ ( prepared_statement, message) = self . rewrite_bind ( message) . await ?;
1281
+ }
1282
+
1190
1283
self . buffer . put ( & message[ ..] ) ;
1191
1284
}
1192
1285
1193
1286
// Describe
1194
1287
// Command a client can issue to describe a previously prepared named statement.
1195
1288
'D' => {
1289
+ if prepared_statements_enabled {
1290
+ let name;
1291
+ ( name, message) = self . rewrite_describe ( message) . await ?;
1292
+
1293
+ if let Some ( name) = name {
1294
+ prepared_statement = Some ( name) ;
1295
+ }
1296
+ }
1297
+
1196
1298
self . buffer . put ( & message[ ..] ) ;
1197
1299
}
1198
1300
@@ -1235,7 +1337,7 @@ where
1235
1337
let first_message_code = ( * self . buffer . get ( 0 ) . unwrap_or ( & 0 ) ) as char ;
1236
1338
1237
1339
// Almost certainly true
1238
- if first_message_code == 'P' {
1340
+ if first_message_code == 'P' && !prepared_statements_enabled {
1239
1341
// Message layout
1240
1342
// P followed by 32 int followed by null-terminated statement name
1241
1343
// So message code should be in offset 0 of the buffer, first character
@@ -1363,6 +1465,107 @@ where
1363
1465
}
1364
1466
}
1365
1467
1468
+ /// Rewrite Parse (F) message to set the prepared statement name to one we control.
1469
+ /// Save it into the client cache.
1470
+ fn rewrite_parse ( & mut self , message : BytesMut ) -> Result < ( Option < String > , BytesMut ) , Error > {
1471
+ let parse: Parse = ( & message) . try_into ( ) ?;
1472
+
1473
+ let name = parse. name . clone ( ) ;
1474
+
1475
+ // Don't rewrite anonymous prepared statements
1476
+ if parse. anonymous ( ) {
1477
+ debug ! ( "Anonymous prepared statement" ) ;
1478
+ return Ok ( ( None , message) ) ;
1479
+ }
1480
+
1481
+ let parse = parse. rename ( ) ;
1482
+
1483
+ debug ! (
1484
+ "Renamed prepared statement `{}` to `{}` and saved to cache" ,
1485
+ name, parse. name
1486
+ ) ;
1487
+
1488
+ self . prepared_statements . insert ( name. clone ( ) , parse. clone ( ) ) ;
1489
+
1490
+ Ok ( ( Some ( name) , parse. try_into ( ) ?) )
1491
+ }
1492
+
1493
+ /// Rewrite the Bind (F) message to use the prepared statement name
1494
+ /// saved in the client cache.
1495
+ async fn rewrite_bind (
1496
+ & mut self ,
1497
+ message : BytesMut ,
1498
+ ) -> Result < ( Option < String > , BytesMut ) , Error > {
1499
+ let bind: Bind = ( & message) . try_into ( ) ?;
1500
+ let name = bind. prepared_statement . clone ( ) ;
1501
+
1502
+ if bind. anonymous ( ) {
1503
+ debug ! ( "Anonymous bind message" ) ;
1504
+ return Ok ( ( None , message) ) ;
1505
+ }
1506
+
1507
+ match self . prepared_statements . get ( & name) {
1508
+ Some ( prepared_stmt) => {
1509
+ let bind = bind. reassign ( prepared_stmt) ;
1510
+
1511
+ debug ! ( "Rewrote bind `{}` to `{}`" , name, bind. prepared_statement) ;
1512
+
1513
+ Ok ( ( Some ( name) , bind. try_into ( ) ?) )
1514
+ }
1515
+ None => {
1516
+ debug ! ( "Got bind for unknown prepared statement {:?}" , bind) ;
1517
+
1518
+ error_response (
1519
+ & mut self . write ,
1520
+ & format ! (
1521
+ "prepared statement \" {}\" does not exist" ,
1522
+ bind. prepared_statement
1523
+ ) ,
1524
+ )
1525
+ . await ?;
1526
+
1527
+ Err ( Error :: ClientError ( format ! (
1528
+ "Prepared statement `{}` doesn't exist" ,
1529
+ name
1530
+ ) ) )
1531
+ }
1532
+ }
1533
+ }
1534
+
1535
+ /// Rewrite the Describe (F) message to use the prepared statement name
1536
+ /// saved in the client cache.
1537
+ async fn rewrite_describe (
1538
+ & mut self ,
1539
+ message : BytesMut ,
1540
+ ) -> Result < ( Option < String > , BytesMut ) , Error > {
1541
+ let describe: Describe = ( & message) . try_into ( ) ?;
1542
+ let name = describe. statement_name . clone ( ) ;
1543
+
1544
+ if describe. anonymous ( ) {
1545
+ debug ! ( "Anonymous describe" ) ;
1546
+ return Ok ( ( None , message) ) ;
1547
+ }
1548
+
1549
+ match self . prepared_statements . get ( & name) {
1550
+ Some ( prepared_stmt) => {
1551
+ let describe = describe. rename ( & prepared_stmt. name ) ;
1552
+
1553
+ debug ! (
1554
+ "Rewrote describe `{}` to `{}`" ,
1555
+ name, describe. statement_name
1556
+ ) ;
1557
+
1558
+ Ok ( ( Some ( name) , describe. try_into ( ) ?) )
1559
+ }
1560
+
1561
+ None => {
1562
+ debug ! ( "Got describe for unknown prepared statement {:?}" , describe) ;
1563
+
1564
+ Ok ( ( None , message) )
1565
+ }
1566
+ }
1567
+ }
1568
+
1366
1569
/// Release the server from the client: it can't cancel its queries anymore.
1367
1570
pub fn release ( & self ) {
1368
1571
let mut guard = self . client_server_map . lock ( ) ;
0 commit comments