Skip to content

Commit 05d7ffb

Browse files
Handle optional datatypes properly in CREATE FUNCTION statements (#1826)
Co-authored-by: Ifeanyi Ubah <[email protected]>
1 parent 3f4d5f9 commit 05d7ffb

File tree

2 files changed

+225
-5
lines changed

2 files changed

+225
-5
lines changed

src/parser/mod.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5273,12 +5273,21 @@ impl<'a> Parser<'a> {
52735273
// parse: [ argname ] argtype
52745274
let mut name = None;
52755275
let mut data_type = self.parse_data_type()?;
5276-
if let DataType::Custom(n, _) = &data_type {
5277-
// the first token is actually a name
5278-
match n.0[0].clone() {
5279-
ObjectNamePart::Identifier(ident) => name = Some(ident),
5276+
5277+
// To check whether the first token is a name or a type, we need to
5278+
// peek the next token, which if it is another type keyword, then the
5279+
// first token is a name and not a type in itself.
5280+
let data_type_idx = self.get_current_index();
5281+
if let Some(next_data_type) = self.maybe_parse(|parser| parser.parse_data_type())? {
5282+
let token = self.token_at(data_type_idx);
5283+
5284+
// We ensure that the token is a `Word` token, and not other special tokens.
5285+
if !matches!(token.token, Token::Word(_)) {
5286+
return self.expected("a name or type", token.clone());
52805287
}
5281-
data_type = self.parse_data_type()?;
5288+
5289+
name = Some(Ident::new(token.to_string()));
5290+
data_type = next_data_type;
52825291
}
52835292

52845293
let default_expr = if self.parse_keyword(Keyword::DEFAULT) || self.consume_token(&Token::Eq)

tests/sqlparser_postgres.rs

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
2222
#[macro_use]
2323
mod test_utils;
24+
2425
use helpers::attached_token::AttachedToken;
2526
use sqlparser::tokenizer::Span;
2627
use test_utils::*;
@@ -4105,6 +4106,216 @@ fn parse_update_in_with_subquery() {
41054106
pg_and_generic().verified_stmt(r#"WITH "result" AS (UPDATE "Hero" SET "name" = 'Captain America', "number_of_movies" = "number_of_movies" + 1 WHERE "secret_identity" = 'Sam Wilson' RETURNING "id", "name", "secret_identity", "number_of_movies") SELECT * FROM "result""#);
41064107
}
41074108

4109+
#[test]
4110+
fn parser_create_function_with_args() {
4111+
let sql1 = r#"CREATE OR REPLACE FUNCTION check_strings_different(str1 VARCHAR, str2 VARCHAR) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4112+
BEGIN
4113+
IF str1 <> str2 THEN
4114+
RETURN TRUE;
4115+
ELSE
4116+
RETURN FALSE;
4117+
END IF;
4118+
END;
4119+
$$"#;
4120+
4121+
assert_eq!(
4122+
pg_and_generic().verified_stmt(sql1),
4123+
Statement::CreateFunction(CreateFunction {
4124+
or_alter: false,
4125+
or_replace: true,
4126+
temporary: false,
4127+
name: ObjectName::from(vec![Ident::new("check_strings_different")]),
4128+
args: Some(vec![
4129+
OperateFunctionArg::with_name(
4130+
"str1",
4131+
DataType::Varchar(None),
4132+
),
4133+
OperateFunctionArg::with_name(
4134+
"str2",
4135+
DataType::Varchar(None),
4136+
),
4137+
]),
4138+
return_type: Some(DataType::Boolean),
4139+
language: Some("plpgsql".into()),
4140+
behavior: None,
4141+
called_on_null: None,
4142+
parallel: None,
4143+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4144+
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF str1 <> str2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
4145+
))),
4146+
if_not_exists: false,
4147+
using: None,
4148+
determinism_specifier: None,
4149+
options: None,
4150+
remote_connection: None,
4151+
})
4152+
);
4153+
4154+
let sql2 = r#"CREATE OR REPLACE FUNCTION check_not_zero(int1 INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4155+
BEGIN
4156+
IF int1 <> 0 THEN
4157+
RETURN TRUE;
4158+
ELSE
4159+
RETURN FALSE;
4160+
END IF;
4161+
END;
4162+
$$"#;
4163+
assert_eq!(
4164+
pg_and_generic().verified_stmt(sql2),
4165+
Statement::CreateFunction(CreateFunction {
4166+
or_alter: false,
4167+
or_replace: true,
4168+
temporary: false,
4169+
name: ObjectName::from(vec![Ident::new("check_not_zero")]),
4170+
args: Some(vec![
4171+
OperateFunctionArg::with_name(
4172+
"int1",
4173+
DataType::Int(None)
4174+
)
4175+
]),
4176+
return_type: Some(DataType::Boolean),
4177+
language: Some("plpgsql".into()),
4178+
behavior: None,
4179+
called_on_null: None,
4180+
parallel: None,
4181+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4182+
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> 0 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
4183+
))),
4184+
if_not_exists: false,
4185+
using: None,
4186+
determinism_specifier: None,
4187+
options: None,
4188+
remote_connection: None,
4189+
})
4190+
);
4191+
4192+
let sql3 = r#"CREATE OR REPLACE FUNCTION check_values_different(a INT, b INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4193+
BEGIN
4194+
IF a <> b THEN
4195+
RETURN TRUE;
4196+
ELSE
4197+
RETURN FALSE;
4198+
END IF;
4199+
END;
4200+
$$"#;
4201+
assert_eq!(
4202+
pg_and_generic().verified_stmt(sql3),
4203+
Statement::CreateFunction(CreateFunction {
4204+
or_alter: false,
4205+
or_replace: true,
4206+
temporary: false,
4207+
name: ObjectName::from(vec![Ident::new("check_values_different")]),
4208+
args: Some(vec![
4209+
OperateFunctionArg::with_name(
4210+
"a",
4211+
DataType::Int(None)
4212+
),
4213+
OperateFunctionArg::with_name(
4214+
"b",
4215+
DataType::Int(None)
4216+
),
4217+
]),
4218+
return_type: Some(DataType::Boolean),
4219+
language: Some("plpgsql".into()),
4220+
behavior: None,
4221+
called_on_null: None,
4222+
parallel: None,
4223+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4224+
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF a <> b THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
4225+
))),
4226+
if_not_exists: false,
4227+
using: None,
4228+
determinism_specifier: None,
4229+
options: None,
4230+
remote_connection: None,
4231+
})
4232+
);
4233+
4234+
let sql4 = r#"CREATE OR REPLACE FUNCTION check_values_different(int1 INT, int2 INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4235+
BEGIN
4236+
IF int1 <> int2 THEN
4237+
RETURN TRUE;
4238+
ELSE
4239+
RETURN FALSE;
4240+
END IF;
4241+
END;
4242+
$$"#;
4243+
assert_eq!(
4244+
pg_and_generic().verified_stmt(sql4),
4245+
Statement::CreateFunction(CreateFunction {
4246+
or_alter: false,
4247+
or_replace: true,
4248+
temporary: false,
4249+
name: ObjectName::from(vec![Ident::new("check_values_different")]),
4250+
args: Some(vec![
4251+
OperateFunctionArg::with_name(
4252+
"int1",
4253+
DataType::Int(None)
4254+
),
4255+
OperateFunctionArg::with_name(
4256+
"int2",
4257+
DataType::Int(None)
4258+
),
4259+
]),
4260+
return_type: Some(DataType::Boolean),
4261+
language: Some("plpgsql".into()),
4262+
behavior: None,
4263+
called_on_null: None,
4264+
parallel: None,
4265+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4266+
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> int2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
4267+
))),
4268+
if_not_exists: false,
4269+
using: None,
4270+
determinism_specifier: None,
4271+
options: None,
4272+
remote_connection: None,
4273+
})
4274+
);
4275+
4276+
let sql5 = r#"CREATE OR REPLACE FUNCTION foo(a TIMESTAMP WITH TIME ZONE, b VARCHAR) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4277+
BEGIN
4278+
RETURN TRUE;
4279+
END;
4280+
$$"#;
4281+
assert_eq!(
4282+
pg_and_generic().verified_stmt(sql5),
4283+
Statement::CreateFunction(CreateFunction {
4284+
or_alter: false,
4285+
or_replace: true,
4286+
temporary: false,
4287+
name: ObjectName::from(vec![Ident::new("foo")]),
4288+
args: Some(vec![
4289+
OperateFunctionArg::with_name(
4290+
"a",
4291+
DataType::Timestamp(None, TimezoneInfo::WithTimeZone)
4292+
),
4293+
OperateFunctionArg::with_name("b", DataType::Varchar(None)),
4294+
]),
4295+
return_type: Some(DataType::Boolean),
4296+
language: Some("plpgsql".into()),
4297+
behavior: None,
4298+
called_on_null: None,
4299+
parallel: None,
4300+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4301+
(Value::DollarQuotedString(DollarQuotedString {
4302+
value: "\n BEGIN\n RETURN TRUE;\n END;\n ".to_owned(),
4303+
tag: None
4304+
}))
4305+
.with_empty_span()
4306+
))),
4307+
if_not_exists: false,
4308+
using: None,
4309+
determinism_specifier: None,
4310+
options: None,
4311+
remote_connection: None,
4312+
})
4313+
);
4314+
4315+
let incorrect_sql = "CREATE FUNCTION add(function(struct<a,b> int64), b INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";
4316+
assert!(pg().parse_sql_statements(incorrect_sql).is_err(),);
4317+
}
4318+
41084319
#[test]
41094320
fn parse_create_function() {
41104321
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";

0 commit comments

Comments
 (0)