Skip to content

Support for the PG wire protocol #2702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
367 changes: 353 additions & 14 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ paste = "1.0"
percent-encoding = "2.3"
petgraph = { version = "0.6.5", default-features = false }
pin-project-lite = "0.2.9"
pgwire = { version = "0.28.0", features = ["server-api"] }
postgres-types = "0.2.5"
pretty_assertions = { version = "1.4", features = ["unstable"] }
proc-macro2 = "1.0"
Expand All @@ -216,6 +217,7 @@ rand08 = { package = "rand", version = "0.8" }
rand = "0.9"
rayon = "1.8"
rayon-core = "1.11.0"
rcgen = { version = "0.13.1", features = ["pem", "x509-parser", "crypto", "ring"] }
regex = "1"
reqwest = { version = "0.12", features = ["stream", "json"] }
ron = "0.8"
Expand All @@ -224,6 +226,8 @@ rust_decimal = { version = "1.29.1", features = ["db-tokio-postgres"] }
rustc-demangle = "0.1.21"
rustc-hash = "2"
rustyline = { version = "12.0.0", features = [] }
rustls-pki-types = "1.11.0"
rustls = "0.23.26"
scoped-tls = "1.0.1"
scopeguard = "1.1.0"
second-stack = "0.3"
Expand Down Expand Up @@ -255,6 +259,7 @@ termcolor = "1.2.0"
thin-vec = "0.2.13"
thiserror = "1.0.37"
tokio = { version = "1.37", features = ["full"] }
tokio-rustls = "0.26.2"
tokio_metrics = { version = "0.4.0" }
tokio-postgres = { version = "0.7.8", features = ["with-chrono-0_4"] }
tokio-stream = "0.1.17"
Expand Down
2 changes: 1 addition & 1 deletion crates/client-api-messages/src/name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ pub enum SetDefaultDomainResult {
///
/// Must match the regex `^[a-z0-9]+(-[a-z0-9]+)*$`
#[derive(Clone, Debug, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)]
pub struct DatabaseName(String);
pub struct DatabaseName(pub String);

impl AsRef<str> for DatabaseName {
fn as_ref(&self) -> &str {
Expand Down
46 changes: 40 additions & 6 deletions crates/client-api/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,30 @@ impl TokenClaims {
}

impl SpacetimeAuth {
pub fn from_claims(
ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized),
claims: SpacetimeIdentityClaims,
) -> axum::response::Result<Self> {
let claims = TokenClaims {
issuer: claims.issuer,
subject: claims.subject,
audience: claims.audience,
};

let creds = {
let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?;
SpacetimeCreds::from_signed_token(token)
};
let identity = claims.id();

Ok(Self {
creds,
identity,
subject: claims.subject,
issuer: claims.issuer,
})
}

/// Allocate a new identity, and mint a new token for it.
pub async fn alloc(ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized)) -> axum::response::Result<Self> {
// Generate claims with a random subject.
Expand Down Expand Up @@ -186,6 +210,8 @@ pub trait JwtAuthProvider: Sync + Send + TokenSigner {
///
/// The `/identity/public-key` route calls this method to return the public key to callers.
fn public_key_bytes(&self) -> &[u8];
/// Return the private key used to verify JWTs, as the bytes of a PEM private key file.
fn private_key_bytes(&self) -> &[u8];
}

pub struct JwtKeyAuthProvider<TV: TokenValidator + Send + Sync> {
Expand Down Expand Up @@ -222,6 +248,10 @@ impl<TV: TokenValidator + Send + Sync> TokenSigner for JwtKeyAuthProvider<TV> {
impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV> {
type TV = TV;

fn validator(&self) -> &Self::TV {
&self.validator
}

fn local_issuer(&self) -> &str {
&self.local_issuer
}
Expand All @@ -230,8 +260,8 @@ impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV
&self.keys.public_pem
}

fn validator(&self) -> &Self::TV {
&self.validator
fn private_key_bytes(&self) -> &[u8] {
&self.keys.private_pem
}
}

Expand Down Expand Up @@ -260,6 +290,13 @@ mod tests {
}
}

pub async fn validate_token<S: NodeDelegate>(
state: &S,
token: &str,
) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
state.jwt_auth_provider().validator().validate_token(token).await
}

pub struct SpacetimeAuthHeader {
auth: Option<SpacetimeAuth>,
}
Expand All @@ -272,10 +309,7 @@ impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for Space
return Ok(Self { auth: None });
};

let claims = state
.jwt_auth_provider()
.validator()
.validate_token(&creds.token)
let claims = validate_token(state, &creds.token)
.await
.map_err(AuthorizationRejection::Custom)?;

Expand Down
36 changes: 25 additions & 11 deletions crates/client-api/src/routes/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::auth::{
SpacetimeIdentityToken,
};
use crate::routes::subscribe::generate_random_connection_id;
use crate::util::{ByteStringBody, NameOrIdentity};
pub use crate::util::{ByteStringBody, NameOrIdentity};
use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate};
use axum::body::{Body, Bytes};
use axum::extract::{Path, Query, State};
Expand All @@ -25,10 +25,11 @@ use spacetimedb::host::ReducerOutcome;
use spacetimedb::host::UpdateDatabaseResult;
use spacetimedb::identity::Identity;
use spacetimedb::messages::control_db::{Database, HostType};
use spacetimedb_client_api_messages::http::SqlStmtResult;
use spacetimedb_client_api_messages::name::{self, DatabaseName, DomainName, PublishOp, PublishResult};
use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::sats;
use spacetimedb_lib::{sats, ProductValue};

use super::subscribe::handle_websocket;

Expand Down Expand Up @@ -381,19 +382,19 @@ async fn worker_ctx_find_database(

#[derive(Deserialize)]
pub struct SqlParams {
name_or_identity: NameOrIdentity,
pub name_or_identity: NameOrIdentity,
}

#[derive(Deserialize)]
pub struct SqlQueryParams {}

pub async fn sql<S>(
State(worker_ctx): State<S>,
Path(SqlParams { name_or_identity }): Path<SqlParams>,
Query(SqlQueryParams {}): Query<SqlQueryParams>,
Extension(auth): Extension<SpacetimeAuth>,
body: String,
) -> axum::response::Result<impl IntoResponse>
pub async fn sql_direct<S>(
worker_ctx: S,
SqlParams { name_or_identity }: SqlParams,
_params: SqlQueryParams,
auth: SpacetimeAuth,
sql: String,
) -> axum::response::Result<Vec<SqlStmtResult<ProductValue>>>
where
S: NodeDelegate + ControlStateDelegate,
{
Expand All @@ -413,7 +414,20 @@ where
.await
.map_err(log_and_500)?
.ok_or(StatusCode::NOT_FOUND)?;
let json = host.exec_sql(auth, database, body).await?;
host.exec_sql(auth, database, sql).await
}

pub async fn sql<S>(
State(worker_ctx): State<S>,
Path(name_or_identity): Path<SqlParams>,
Query(params): Query<SqlQueryParams>,
Extension(auth): Extension<SpacetimeAuth>,
body: String,
) -> axum::response::Result<impl IntoResponse>
where
S: NodeDelegate + ControlStateDelegate,
{
let json = sql_direct(worker_ctx, name_or_identity, params, auth, body).await?;

let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros);

Expand Down
9 changes: 6 additions & 3 deletions crates/core/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct JwtKeys {
pub public: DecodingKey,
pub public_pem: Box<[u8]>,
pub private: EncodingKey,
pub private_pem: Box<[u8]>,
pub kid: Option<String>,
}

Expand All @@ -23,15 +24,17 @@ impl JwtKeys {
/// respectively.
///
/// The key files must be PEM encoded ECDSA P256 keys.
pub fn new(public_pem: impl Into<Box<[u8]>>, private_pem: &[u8]) -> anyhow::Result<Self> {
pub fn new(public_pem: impl Into<Box<[u8]>>, private_pem: impl Into<Box<[u8]>>) -> anyhow::Result<Self> {
let public_pem = public_pem.into();
let private_pem = private_pem.into();
let public = DecodingKey::from_ec_pem(&public_pem)?;
let private = EncodingKey::from_ec_pem(private_pem)?;
let private = EncodingKey::from_ec_pem(&private_pem)?;

Ok(Self {
public,
private,
public_pem,
private_pem,
kid: None,
})
}
Expand Down Expand Up @@ -75,7 +78,7 @@ pub struct EcKeyPair {
impl TryFrom<EcKeyPair> for JwtKeys {
type Error = anyhow::Error;
fn try_from(pair: EcKeyPair) -> anyhow::Result<Self> {
JwtKeys::new(pair.public_key_bytes, &pair.private_key_bytes)
JwtKeys::new(pair.public_key_bytes, pair.private_key_bytes)
}
}

Expand Down
6 changes: 6 additions & 0 deletions crates/standalone/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,22 @@ http.workspace = true
log.workspace = true
openssl.workspace = true
parse-size.workspace = true
pgwire.workspace = true
prometheus.workspace = true
scopeguard.workspace = true
serde.workspace = true
serde_json.workspace = true
sled.workspace = true
socket2.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tower-http.workspace = true
toml.workspace = true
tracing = { workspace = true, features = ["release_max_level_debug"] }
rustls-pki-types.workspace = true
rcgen.workspace = true
rustls.workspace = true

[target.'cfg(not(target_env = "msvc"))'.dependencies]
tikv-jemallocator = {workspace = true}
Expand Down
1 change: 1 addition & 0 deletions crates/standalone/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod control_db;
pub mod pg_server;
pub mod subcommands;
pub mod util;
pub mod version;
Expand Down
Loading
Loading