Skip to content

Commit 767e112

Browse files
committed
Add SET configuration_parameter support for PostgreSQL functions.
1 parent 644c57c commit 767e112

File tree

6 files changed

+113
-7
lines changed

6 files changed

+113
-7
lines changed

src/ast/ddl.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,14 @@ use crate::ast::{
4343
},
4444
ArgMode, AttachedToken, CommentDef, ConditionalStatements, CreateFunctionBody,
4545
CreateFunctionUsing, CreateTableLikeKind, CreateTableOptions, CreateViewParams, DataType, Expr,
46-
FileFormat, FunctionBehavior, FunctionCalledOnNull, FunctionDesc, FunctionDeterminismSpecifier,
47-
FunctionParallel, FunctionSecurity, HiveDistributionStyle, HiveFormat, HiveIOFormat,
48-
HiveRowFormat, HiveSetLocation, Ident, InitializeKind, MySQLColumnPosition, ObjectName,
49-
OnCommit, OneOrManyWithParens, OperateFunctionArg, OrderByExpr, ProjectionSelect, Query,
50-
RefreshModeKind, RowAccessPolicy, SequenceOptions, Spanned, SqlOption,
51-
StorageSerializationPolicy, TableVersion, Tag, TriggerEvent, TriggerExecBody, TriggerObject,
52-
TriggerPeriod, TriggerReferencing, Value, ValueWithSpan, WrappedCollection,
46+
FileFormat, FunctionBehavior, FunctionCalledOnNull, FunctionDefinitionSetParam, FunctionDesc,
47+
FunctionDeterminismSpecifier, FunctionParallel, FunctionSecurity, HiveDistributionStyle,
48+
HiveFormat, HiveIOFormat, HiveRowFormat, HiveSetLocation, Ident, InitializeKind,
49+
MySQLColumnPosition, ObjectName, OnCommit, OneOrManyWithParens, OperateFunctionArg,
50+
OrderByExpr, ProjectionSelect, Query, RefreshModeKind, RowAccessPolicy, SequenceOptions,
51+
Spanned, SqlOption, StorageSerializationPolicy, TableVersion, Tag, TriggerEvent,
52+
TriggerExecBody, TriggerObject, TriggerPeriod, TriggerReferencing, Value, ValueWithSpan,
53+
WrappedCollection,
5354
};
5455
use crate::display_utils::{DisplayCommaSeparated, Indent, NewLine, SpaceOrNewline};
5556
use crate::keywords::Keyword;
@@ -3230,6 +3231,10 @@ pub struct CreateFunction {
32303231
///
32313232
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
32323233
pub security: Option<FunctionSecurity>,
3234+
/// SET configuration_parameter clauses
3235+
///
3236+
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
3237+
pub set_params: Vec<FunctionDefinitionSetParam>,
32333238
/// USING ... (Hive only)
32343239
pub using: Option<CreateFunctionUsing>,
32353240
/// Language used in a UDF definition.
@@ -3299,6 +3304,9 @@ impl fmt::Display for CreateFunction {
32993304
if let Some(security) = &self.security {
33003305
write!(f, " {security}")?;
33013306
}
3307+
for set_param in &self.set_params {
3308+
write!(f, " {set_param}")?;
3309+
}
33023310
if let Some(remote_connection) = &self.remote_connection {
33033311
write!(f, " REMOTE WITH CONNECTION {remote_connection}")?;
33043312
}

src/ast/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8801,6 +8801,42 @@ impl fmt::Display for FunctionSecurity {
88018801
}
88028802
}
88038803

8804+
/// Value for a SET configuration parameter in a CREATE FUNCTION statement.
8805+
///
8806+
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
8807+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
8808+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8809+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
8810+
pub enum FunctionSetValue {
8811+
/// SET param = value1, value2, ...
8812+
Values(Vec<Expr>),
8813+
/// SET param FROM CURRENT
8814+
FromCurrent,
8815+
}
8816+
8817+
/// A SET configuration_parameter clause in a CREATE FUNCTION statement.
8818+
///
8819+
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
8820+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
8821+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8822+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
8823+
pub struct FunctionDefinitionSetParam {
8824+
pub name: Ident,
8825+
pub value: FunctionSetValue,
8826+
}
8827+
8828+
impl fmt::Display for FunctionDefinitionSetParam {
8829+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
8830+
write!(f, "SET {} ", self.name)?;
8831+
match &self.value {
8832+
FunctionSetValue::Values(values) => {
8833+
write!(f, "= {}", display_comma_separated(values))
8834+
}
8835+
FunctionSetValue::FromCurrent => write!(f, "FROM CURRENT"),
8836+
}
8837+
}
8838+
}
8839+
88048840
/// These attributes describe the behavior of the function when called with a null argument.
88058841
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
88068842
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/parser/mod.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5263,6 +5263,7 @@ impl<'a> Parser<'a> {
52635263
security: Option<FunctionSecurity>,
52645264
}
52655265
let mut body = Body::default();
5266+
let mut set_params: Vec<FunctionDefinitionSetParam> = Vec::new();
52665267
loop {
52675268
fn ensure_not_set<T>(field: &Option<T>, name: &str) -> Result<(), ParserError> {
52685269
if field.is_some() {
@@ -5336,6 +5337,18 @@ impl<'a> Parser<'a> {
53365337
} else {
53375338
return self.expected("DEFINER or INVOKER", self.peek_token());
53385339
}
5340+
} else if self.parse_keyword(Keyword::SET) {
5341+
let name = self.parse_identifier()?;
5342+
let value = if self.parse_keywords(&[Keyword::FROM, Keyword::CURRENT]) {
5343+
FunctionSetValue::FromCurrent
5344+
} else {
5345+
if !self.consume_token(&Token::Eq) && !self.parse_keyword(Keyword::TO) {
5346+
return self.expected("= or TO", self.peek_token());
5347+
}
5348+
let values = self.parse_comma_separated(Parser::parse_expr)?;
5349+
FunctionSetValue::Values(values)
5350+
};
5351+
set_params.push(FunctionDefinitionSetParam { name, value });
53395352
} else if self.parse_keyword(Keyword::RETURN) {
53405353
ensure_not_set(&body.function_body, "RETURN")?;
53415354
body.function_body = Some(CreateFunctionBody::Return(self.parse_expr()?));
@@ -5355,6 +5368,7 @@ impl<'a> Parser<'a> {
53555368
called_on_null: body.called_on_null,
53565369
parallel: body.parallel,
53575370
security: body.security,
5371+
set_params,
53585372
language: body.language,
53595373
function_body: body.function_body,
53605374
if_not_exists: false,
@@ -5393,6 +5407,7 @@ impl<'a> Parser<'a> {
53935407
called_on_null: None,
53945408
parallel: None,
53955409
security: None,
5410+
set_params: vec![],
53965411
language: None,
53975412
determinism_specifier: None,
53985413
options: None,
@@ -5476,6 +5491,7 @@ impl<'a> Parser<'a> {
54765491
called_on_null: None,
54775492
parallel: None,
54785493
security: None,
5494+
set_params: vec![],
54795495
}))
54805496
}
54815497

@@ -5566,6 +5582,7 @@ impl<'a> Parser<'a> {
55665582
called_on_null: None,
55675583
parallel: None,
55685584
security: None,
5585+
set_params: vec![],
55695586
}))
55705587
}
55715588

tests/sqlparser_bigquery.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,6 +2295,7 @@ fn test_bigquery_create_function() {
22952295
called_on_null: None,
22962296
parallel: None,
22972297
security: None,
2298+
set_params: vec![],
22982299
})
22992300
);
23002301

tests/sqlparser_mssql.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ fn parse_create_function() {
267267
called_on_null: None,
268268
parallel: None,
269269
security: None,
270+
set_params: vec![],
270271
using: None,
271272
language: None,
272273
determinism_specifier: None,
@@ -441,6 +442,7 @@ fn parse_create_function_parameter_default_values() {
441442
called_on_null: None,
442443
parallel: None,
443444
security: None,
445+
set_params: vec![],
444446
using: None,
445447
language: None,
446448
determinism_specifier: None,

tests/sqlparser_postgres.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4295,6 +4295,7 @@ $$"#;
42954295
called_on_null: None,
42964296
parallel: None,
42974297
security: None,
4298+
set_params: vec![],
42984299
function_body: Some(CreateFunctionBody::AsBeforeOptions {
42994300
body: Expr::Value(
43004301
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF str1 <> str2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4337,6 +4338,7 @@ $$"#;
43374338
called_on_null: None,
43384339
parallel: None,
43394340
security: None,
4341+
set_params: vec![],
43404342
function_body: Some(CreateFunctionBody::AsBeforeOptions {
43414343
body: Expr::Value(
43424344
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> 0 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4383,6 +4385,7 @@ $$"#;
43834385
called_on_null: None,
43844386
parallel: None,
43854387
security: None,
4388+
set_params: vec![],
43864389
function_body: Some(CreateFunctionBody::AsBeforeOptions {
43874390
body: Expr::Value(
43884391
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF a <> b THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4429,6 +4432,7 @@ $$"#;
44294432
called_on_null: None,
44304433
parallel: None,
44314434
security: None,
4435+
set_params: vec![],
44324436
function_body: Some(CreateFunctionBody::AsBeforeOptions {
44334437
body: Expr::Value(
44344438
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> int2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4468,6 +4472,7 @@ $$"#;
44684472
called_on_null: None,
44694473
parallel: None,
44704474
security: None,
4475+
set_params: vec![],
44714476
function_body: Some(CreateFunctionBody::AsBeforeOptions {
44724477
body: Expr::Value(
44734478
(Value::DollarQuotedString(DollarQuotedString {
@@ -4510,6 +4515,7 @@ fn parse_create_function() {
45104515
called_on_null: Some(FunctionCalledOnNull::Strict),
45114516
parallel: Some(FunctionParallel::Safe),
45124517
security: None,
4518+
set_params: vec![],
45134519
function_body: Some(CreateFunctionBody::AsBeforeOptions {
45144520
body: Expr::Value(
45154521
(Value::SingleQuotedString("select $1 + $2;".into())).with_empty_span()
@@ -4561,6 +4567,40 @@ fn parse_create_function_with_security() {
45614567
}
45624568
}
45634569

4570+
#[test]
4571+
fn parse_create_function_with_set_params() {
4572+
let sql =
4573+
"CREATE FUNCTION test_fn() RETURNS void LANGUAGE sql SET search_path = auth, pg_temp, public AS $$ SELECT 1 $$";
4574+
match pg_and_generic().verified_stmt(sql) {
4575+
Statement::CreateFunction(CreateFunction { set_params, .. }) => {
4576+
assert_eq!(set_params.len(), 1);
4577+
assert_eq!(set_params[0].name.to_string(), "search_path");
4578+
}
4579+
_ => panic!("Expected CreateFunction"),
4580+
}
4581+
4582+
// Test multiple SET params
4583+
let sql2 =
4584+
"CREATE FUNCTION test_fn() RETURNS void LANGUAGE sql SET search_path = public SET statement_timeout = '5s' AS $$ SELECT 1 $$";
4585+
match pg_and_generic().verified_stmt(sql2) {
4586+
Statement::CreateFunction(CreateFunction { set_params, .. }) => {
4587+
assert_eq!(set_params.len(), 2);
4588+
}
4589+
_ => panic!("Expected CreateFunction"),
4590+
}
4591+
4592+
// Test FROM CURRENT
4593+
let sql3 =
4594+
"CREATE FUNCTION test_fn() RETURNS void LANGUAGE sql SET search_path FROM CURRENT AS $$ SELECT 1 $$";
4595+
match pg_and_generic().verified_stmt(sql3) {
4596+
Statement::CreateFunction(CreateFunction { set_params, .. }) => {
4597+
assert_eq!(set_params.len(), 1);
4598+
assert!(matches!(set_params[0].value, FunctionSetValue::FromCurrent));
4599+
}
4600+
_ => panic!("Expected CreateFunction"),
4601+
}
4602+
}
4603+
45644604
#[test]
45654605
fn parse_incorrect_create_function_parallel() {
45664606
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL PARALLEL BLAH AS 'select $1 + $2;'";
@@ -4590,6 +4630,7 @@ fn parse_create_function_c_with_module_pathname() {
45904630
called_on_null: None,
45914631
parallel: Some(FunctionParallel::Safe),
45924632
security: None,
4633+
set_params: vec![],
45934634
function_body: Some(CreateFunctionBody::AsBeforeOptions {
45944635
body: Expr::Value(
45954636
(Value::SingleQuotedString("MODULE_PATHNAME".into())).with_empty_span()
@@ -6216,6 +6257,7 @@ fn parse_trigger_related_functions() {
62166257
called_on_null: None,
62176258
parallel: None,
62186259
security: None,
6260+
set_params: vec![],
62196261
using: None,
62206262
language: Some(Ident::new("plpgsql")),
62216263
determinism_specifier: None,

0 commit comments

Comments
 (0)