Skip to content

Commit da00612

Browse files
authored
feat: enable positionally inserting joins in a FROM clause (prisma#258)
1 parent ac3d9f2 commit da00612

File tree

9 files changed

+387
-7
lines changed

9 files changed

+387
-7
lines changed

src/ast/table.rs

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use super::{Column, Comparable, ConditionTree, DefaultValue, ExpressionKind, IndexDefinition};
1+
use super::{Column, Comparable, ConditionTree, DefaultValue, ExpressionKind, IndexDefinition, Join, JoinData};
22
use crate::{
33
ast::{Expression, Row, Select, Values},
44
error::{Error, ErrorKind},
@@ -19,6 +19,7 @@ pub trait Aliasable<'a> {
1919
/// Either an identifier or a nested query.
2020
pub enum TableType<'a> {
2121
Table(Cow<'a, str>),
22+
JoinedTable((Cow<'a, str>, Vec<Join<'a>>)),
2223
Query(Select<'a>),
2324
Values(Values<'a>),
2425
}
@@ -126,6 +127,194 @@ impl<'a> Table<'a> {
126127

127128
Ok(result)
128129
}
130+
131+
/// Adds a `LEFT JOIN` clause to the query, specifically for that table.
132+
/// Useful to positionally add a JOIN clause in case you are selecting from multiple tables.
133+
///
134+
/// ```rust
135+
/// # use quaint::{ast::*, visitor::{Visitor, Sqlite}};
136+
/// # fn main() -> Result<(), quaint::error::Error> {
137+
/// let join = "posts".alias("p").on(("p", "visible").equals(true));
138+
/// let joined_table = Table::from("users").left_join(join);
139+
/// let query = Select::from_table(joined_table).and_from("comments");
140+
/// let (sql, params) = Sqlite::build(query)?;
141+
///
142+
/// assert_eq!(
143+
/// "SELECT `users`.*, `comments`.* FROM \
144+
/// `users` LEFT JOIN `posts` AS `p` ON `p`.`visible` = ?, \
145+
/// `comments`",
146+
/// sql
147+
/// );
148+
///
149+
/// assert_eq!(
150+
/// vec![
151+
/// Value::from(true),
152+
/// ],
153+
/// params
154+
/// );
155+
/// # Ok(())
156+
/// # }
157+
/// ```
158+
pub fn left_join<J>(mut self, join: J) -> Self
159+
where
160+
J: Into<JoinData<'a>>,
161+
{
162+
match self.typ {
163+
TableType::Table(table_name) => {
164+
self.typ = TableType::JoinedTable((table_name, vec![Join::Left(join.into())]))
165+
}
166+
TableType::JoinedTable((_, ref mut joins)) => joins.push(Join::Left(join.into())),
167+
TableType::Query(_) => {
168+
panic!("You cannot left_join on a table of type Query")
169+
}
170+
TableType::Values(_) => {
171+
panic!("You cannot left_join on a table of type Values")
172+
}
173+
}
174+
175+
self
176+
}
177+
178+
/// Adds an `INNER JOIN` clause to the query, specifically for that table.
179+
/// Useful to positionally add a JOIN clause in case you are selecting from multiple tables.
180+
///
181+
/// ```rust
182+
/// # use quaint::{ast::*, visitor::{Visitor, Sqlite}};
183+
/// # fn main() -> Result<(), quaint::error::Error> {
184+
/// let join = "posts".alias("p").on(("p", "visible").equals(true));
185+
/// let joined_table = Table::from("users").inner_join(join);
186+
/// let query = Select::from_table(joined_table).and_from("comments");
187+
/// let (sql, params) = Sqlite::build(query)?;
188+
///
189+
/// assert_eq!(
190+
/// "SELECT `users`.*, `comments`.* FROM \
191+
/// `users` INNER JOIN `posts` AS `p` ON `p`.`visible` = ?, \
192+
/// `comments`",
193+
/// sql
194+
/// );
195+
///
196+
/// assert_eq!(
197+
/// vec![
198+
/// Value::from(true),
199+
/// ],
200+
/// params
201+
/// );
202+
/// # Ok(())
203+
/// # }
204+
/// ```
205+
pub fn inner_join<J>(mut self, join: J) -> Self
206+
where
207+
J: Into<JoinData<'a>>,
208+
{
209+
match self.typ {
210+
TableType::Table(table_name) => {
211+
self.typ = TableType::JoinedTable((table_name, vec![Join::Inner(join.into())]))
212+
}
213+
TableType::JoinedTable((_, ref mut joins)) => joins.push(Join::Inner(join.into())),
214+
TableType::Query(_) => {
215+
panic!("You cannot inner_join on a table of type Query")
216+
}
217+
TableType::Values(_) => {
218+
panic!("You cannot inner_join on a table of type Values")
219+
}
220+
}
221+
222+
self
223+
}
224+
225+
/// Adds a `RIGHT JOIN` clause to the query, specifically for that table.
226+
/// Useful to positionally add a JOIN clause in case you are selecting from multiple tables.
227+
///
228+
/// ```rust
229+
/// # use quaint::{ast::*, visitor::{Visitor, Sqlite}};
230+
/// # fn main() -> Result<(), quaint::error::Error> {
231+
/// let join = "posts".alias("p").on(("p", "visible").equals(true));
232+
/// let joined_table = Table::from("users").right_join(join);
233+
/// let query = Select::from_table(joined_table).and_from("comments");
234+
/// let (sql, params) = Sqlite::build(query)?;
235+
///
236+
/// assert_eq!(
237+
/// "SELECT `users`.*, `comments`.* FROM \
238+
/// `users` RIGHT JOIN `posts` AS `p` ON `p`.`visible` = ?, \
239+
/// `comments`",
240+
/// sql
241+
/// );
242+
///
243+
/// assert_eq!(
244+
/// vec![
245+
/// Value::from(true),
246+
/// ],
247+
/// params
248+
/// );
249+
/// # Ok(())
250+
/// # }
251+
/// ```
252+
pub fn right_join<J>(mut self, join: J) -> Self
253+
where
254+
J: Into<JoinData<'a>>,
255+
{
256+
match self.typ {
257+
TableType::Table(table_name) => {
258+
self.typ = TableType::JoinedTable((table_name, vec![Join::Right(join.into())]))
259+
}
260+
TableType::JoinedTable((_, ref mut joins)) => joins.push(Join::Right(join.into())),
261+
TableType::Query(_) => {
262+
panic!("You cannot right_join on a table of type Query")
263+
}
264+
TableType::Values(_) => {
265+
panic!("You cannot right_join on a table of type Values")
266+
}
267+
}
268+
269+
self
270+
}
271+
272+
/// Adds a `FULL JOIN` clause to the query, specifically for that table.
273+
/// Useful to positionally add a JOIN clause in case you are selecting from multiple tables.
274+
///
275+
/// ```rust
276+
/// # use quaint::{ast::*, visitor::{Visitor, Sqlite}};
277+
/// # fn main() -> Result<(), quaint::error::Error> {
278+
/// let join = "posts".alias("p").on(("p", "visible").equals(true));
279+
/// let joined_table = Table::from("users").full_join(join);
280+
/// let query = Select::from_table(joined_table).and_from("comments");
281+
/// let (sql, params) = Sqlite::build(query)?;
282+
///
283+
/// assert_eq!(
284+
/// "SELECT `users`.*, `comments`.* FROM \
285+
/// `users` FULL JOIN `posts` AS `p` ON `p`.`visible` = ?, \
286+
/// `comments`",
287+
/// sql
288+
/// );
289+
///
290+
/// assert_eq!(
291+
/// vec![
292+
/// Value::from(true),
293+
/// ],
294+
/// params
295+
/// );
296+
/// # Ok(())
297+
/// # }
298+
/// ```
299+
pub fn full_join<J>(mut self, join: J) -> Self
300+
where
301+
J: Into<JoinData<'a>>,
302+
{
303+
match self.typ {
304+
TableType::Table(table_name) => {
305+
self.typ = TableType::JoinedTable((table_name, vec![Join::Full(join.into())]))
306+
}
307+
TableType::JoinedTable((_, ref mut joins)) => joins.push(Join::Full(join.into())),
308+
TableType::Query(_) => {
309+
panic!("You cannot full_join on a table of type Query")
310+
}
311+
TableType::Values(_) => {
312+
panic!("You cannot full_join on a table of type Values")
313+
}
314+
}
315+
316+
self
317+
}
129318
}
130319

131320
impl<'a> From<&'a str> for Table<'a> {

src/connector/sqlite.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,7 @@ impl TryFrom<&str> for Sqlite {
111111

112112
let client = Mutex::new(conn);
113113

114-
Ok(Sqlite {
115-
client,
116-
})
114+
Ok(Sqlite { client })
117115
}
118116
}
119117

src/single.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,11 @@ impl Quaint {
169169
#[cfg_attr(feature = "docs", doc(cfg(sqlite)))]
170170
/// Open a new SQLite database in memory.
171171
pub fn new_in_memory() -> crate::Result<Quaint> {
172-
173172
Ok(Quaint {
174173
inner: Arc::new(connector::Sqlite::new_in_memory()?),
175-
connection_info: Arc::new(ConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned() }),
174+
connection_info: Arc::new(ConnectionInfo::InMemorySqlite {
175+
db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(),
176+
}),
176177
})
177178
}
178179

src/tests/query.rs

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,59 @@ async fn inner_join(api: &mut dyn TestApi) -> crate::Result<()> {
350350
Ok(())
351351
}
352352

353+
#[test_each_connector]
354+
async fn table_inner_join(api: &mut dyn TestApi) -> crate::Result<()> {
355+
let table1 = api.create_table("id int, name varchar(255)").await?;
356+
let table2 = api.create_table("t1_id int, is_cat int").await?;
357+
let table3 = api.create_table("id int, foo int").await?;
358+
359+
let insert = Insert::multi_into(&table1, vec!["id", "name"])
360+
.values(vec![Value::integer(1), Value::text("Musti")])
361+
.values(vec![Value::integer(2), Value::text("Belka")]);
362+
363+
api.conn().insert(insert.into()).await?;
364+
365+
let insert = Insert::multi_into(&table2, vec!["t1_id", "is_cat"])
366+
.values(vec![Value::integer(1), Value::integer(1)])
367+
.values(vec![Value::integer(2), Value::integer(0)]);
368+
369+
api.conn().insert(insert.into()).await?;
370+
371+
let insert = Insert::multi_into(&table3, vec!["id", "foo"]).values(vec![Value::integer(1), Value::integer(1)]);
372+
373+
api.conn().insert(insert.into()).await?;
374+
375+
let joined_table = Table::from(&table1).inner_join(
376+
table2
377+
.as_str()
378+
.on((table1.as_str(), "id").equals(Column::from((&table2, "t1_id")))),
379+
);
380+
381+
let query = Select::from_table(joined_table)
382+
// Select from a third table to ensure that the JOIN is specifically applied on the table1
383+
.and_from(&table3)
384+
.column((&table1, "name"))
385+
.column((&table2, "is_cat"))
386+
.column((&table3, "foo"))
387+
.order_by(Column::from((&table1, "id")).ascend());
388+
389+
let res = api.conn().select(query).await?;
390+
391+
assert_eq!(2, res.len());
392+
393+
let row = res.get(0).unwrap();
394+
assert_eq!(Some("Musti"), row["name"].as_str());
395+
assert_eq!(Some(true), row["is_cat"].as_bool());
396+
assert_eq!(Some(true), row["foo"].as_bool());
397+
398+
let row = res.get(1).unwrap();
399+
assert_eq!(Some("Belka"), row["name"].as_str());
400+
assert_eq!(Some(false), row["is_cat"].as_bool());
401+
assert_eq!(Some(true), row["foo"].as_bool());
402+
403+
Ok(())
404+
}
405+
353406
#[test_each_connector]
354407
async fn left_join(api: &mut dyn TestApi) -> crate::Result<()> {
355408
let table1 = api.create_table("id int, name varchar(255)").await?;
@@ -391,6 +444,60 @@ async fn left_join(api: &mut dyn TestApi) -> crate::Result<()> {
391444
Ok(())
392445
}
393446

447+
#[test_each_connector]
448+
async fn table_left_join(api: &mut dyn TestApi) -> crate::Result<()> {
449+
let table1 = api.create_table("id int, name varchar(255)").await?;
450+
let table2 = api.create_table("t1_id int, is_cat int").await?;
451+
let table3 = api.create_table("id int, foo int").await?;
452+
453+
let insert = Insert::multi_into(&table1, vec!["id", "name"])
454+
.values(vec![Value::integer(1), Value::text("Musti")])
455+
.values(vec![Value::integer(2), Value::text("Belka")]);
456+
457+
api.conn().insert(insert.into()).await?;
458+
459+
let insert =
460+
Insert::multi_into(&table2, vec!["t1_id", "is_cat"]).values(vec![Value::integer(1), Value::integer(1)]);
461+
462+
api.conn().insert(insert.into()).await?;
463+
464+
let insert = Insert::multi_into(&table3, vec!["id", "foo"]).values(vec![Value::integer(1), Value::integer(1)]);
465+
466+
api.conn().insert(insert.into()).await?;
467+
468+
let joined_table = Table::from(&table1).left_join(
469+
table2
470+
.as_str()
471+
.on((&table1, "id").equals(Column::from((&table2, "t1_id")))),
472+
);
473+
474+
let query = Select::from_table(joined_table)
475+
// Select from a third table to ensure that the JOIN is specifically applied on the table1
476+
.and_from(&table3)
477+
.column((&table1, "name"))
478+
.column((&table2, "is_cat"))
479+
.column((&table3, "foo"))
480+
.order_by(Column::from((&table1, "id")).ascend());
481+
482+
let res = api.conn().select(query).await?;
483+
484+
println!("{:?}", &res);
485+
486+
assert_eq!(2, res.len());
487+
488+
let row = res.get(0).unwrap();
489+
assert_eq!(Some("Musti"), row["name"].as_str());
490+
assert_eq!(Some(true), row["is_cat"].as_bool());
491+
assert_eq!(Some(true), row["foo"].as_bool());
492+
493+
let row = res.get(1).unwrap();
494+
assert_eq!(Some("Belka"), row["name"].as_str());
495+
assert_eq!(None, row["is_cat"].as_bool());
496+
assert_eq!(Some(true), row["foo"].as_bool());
497+
498+
Ok(())
499+
}
500+
394501
#[test_each_connector]
395502
async fn limit_no_offset(api: &mut dyn TestApi) -> crate::Result<()> {
396503
let table = api.create_table("id int, name varchar(255)").await?;

0 commit comments

Comments
 (0)