diff --git a/queries/qm/query_mods.go b/queries/qm/query_mods.go index e8bf752eb..96c2e6d36 100644 --- a/queries/qm/query_mods.go +++ b/queries/qm/query_mods.go @@ -194,13 +194,14 @@ func Distinct(clause string) QueryMod { } type withQueryMod struct { + alias string clause string args []interface{} } // Apply implements QueryMod.Apply. func (qm withQueryMod) Apply(q *queries.Query) { - queries.AppendWith(q, qm.clause, qm.args...) + queries.AppendWith(q, qm.alias, qm.clause, qm.args...) } // With allows you to pass in a Common Table Expression clause (and args) @@ -211,6 +212,17 @@ func With(clause string, args ...interface{}) QueryMod { } } +// WithSubquery allows you to generate a Common Table Expression using a query +// object to populate the CTE +func WithSubquery(alias string, q *queries.Query) QueryMod { + clause, args := queries.BuildSubquery(q) + return withQueryMod{ + alias: alias, + clause: clause, + args: args, + } +} + type selectQueryMod struct { columns []string } diff --git a/queries/query.go b/queries/query.go index 5e99d4ebd..97637fe34 100644 --- a/queries/query.go +++ b/queries/query.go @@ -32,7 +32,7 @@ type Query struct { delete bool update map[string]interface{} - withs []argClause + withs []with selectCols []string count bool from []string @@ -88,6 +88,12 @@ type argClause struct { args []interface{} } +type with struct { + alias string + clause string + args []interface{} +} + type rawSQL struct { sql string args []interface{} @@ -398,8 +404,8 @@ func AppendOrderBy(q *Query, clause string, args ...interface{}) { } // AppendWith on the query. -func AppendWith(q *Query, clause string, args ...interface{}) { - q.withs = append(q.withs, argClause{clause: clause, args: args}) +func AppendWith(q *Query, alias, clause string, args ...interface{}) { + q.withs = append(q.withs, with{alias: alias, clause: clause, args: args}) } // RemoveSoftDeleteWhere prevents the automatic soft delete where clause diff --git a/queries/query_builders.go b/queries/query_builders.go index 7f5cc1a72..77c83f478 100644 --- a/queries/query_builders.go +++ b/queries/query_builders.go @@ -20,6 +20,18 @@ var ( // and it's accompanying arguments. Using this method // allows query building without immediate execution. func BuildQuery(q *Query) (string, []interface{}) { + return buildQuery(q, true) +} + +// BuildSubquery builds a query object into the query string +// and it's accompanying arguments but doesn't append a +// semi-colon allowing the resulting string to be embedded +// as a subquery. +func BuildSubquery(q *Query) (string, []interface{}) { + return buildQuery(q, false) +} + +func buildQuery(q *Query, finalize bool) (string, []interface{}) { var buf *bytes.Buffer var args []interface{} @@ -36,6 +48,10 @@ func BuildQuery(q *Query) (string, []interface{}) { buf, args = buildSelectQuery(q) } + if finalize { + buf.WriteByte(';') + } + defer strmangle.PutBuffer(buf) // Cache the generated query for query object re-use @@ -133,7 +149,6 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { writeModifiers(q, buf, &args) - buf.WriteByte(';') return buf, args } @@ -155,8 +170,6 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { writeModifiers(q, buf, &args) - buf.WriteByte(';') - return buf, args } @@ -199,8 +212,6 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { writeModifiers(q, buf, &args) - buf.WriteByte(';') - return buf, args } @@ -614,7 +625,11 @@ func writeCTEs(q *Query, buf *bytes.Buffer, args *[]interface{}) { withBuf := strmangle.GetBuffer() lastPos := len(q.withs) - 1 for i, w := range q.withs { - fmt.Fprintf(withBuf, " %s", w.clause) + if w.alias != "" { + fmt.Fprintf(withBuf, " %s AS (%s)", w.alias, w.clause) + } else { + fmt.Fprintf(withBuf, " %s", w.clause) + } if i >= 0 && i < lastPos { withBuf.WriteByte(',') } diff --git a/queries/query_builders_test.go b/queries/query_builders_test.go index 892d8dde0..72e94bfb9 100644 --- a/queries/query_builders_test.go +++ b/queries/query_builders_test.go @@ -119,9 +119,9 @@ func TestBuildQuery(t *testing.T) { {&Query{from: []string{"cats as c", "dogs as d"}, joins: []join{{JoinOuterFull, "dogs d on d.cat_id = cats.id", nil}}}, nil}, {&Query{ from: []string{"t"}, - withs: []argClause{ - {"cte_0 AS (SELECT * FROM other_t0)", nil}, - {"cte_1 AS (SELECT * FROM other_t1 WHERE thing=? AND stuff=?)", []interface{}{3, 7}}, + withs: []with{ + {"cte_0", "SELECT * FROM other_t0", nil}, + {"cte_1", "SELECT * FROM other_t1 WHERE thing=? AND stuff=?", []interface{}{3, 7}}, }, }, []interface{}{3, 7}, }, @@ -161,6 +161,31 @@ func TestBuildQuery(t *testing.T) { } } +func TestBuildSubquery(t *testing.T) { + t.Parallel() + + q1 := &Query{} + SetSelect(q1, []string{"foo", "bar"}) + SetFrom(q1, "tbl") + q1.dialect = &drivers.Dialect{LQ: '"', RQ: '"', UseIndexPlaceholders: true} + + q2 := &Query{} + SetSelect(q2, []string{"foo", "bar"}) + SetFrom(q2, "tbl") + q2.dialect = &drivers.Dialect{LQ: '"', RQ: '"', UseIndexPlaceholders: true} + + query, _ := BuildQuery(q1) + subquery, _ := BuildSubquery(q2) + + if !strings.HasSuffix(query, ";") { + t.Error("BuildQuery() result is missing trailing ';'\n", query) + } + + if strings.HasSuffix(subquery, ";") { + t.Error("BuildSubquery() result has trailing ';'\n", subquery) + } +} + func TestWriteStars(t *testing.T) { t.Parallel() @@ -516,11 +541,11 @@ func TestLimitClause(t *testing.T) { t.Parallel() tests := []struct { - limit *int + limit *int expectPredicate func(sql string) bool }{ {nil, func(sql string) bool { - return !strings.Contains(sql,"LIMIT") + return !strings.Contains(sql, "LIMIT") }}, {newIntPtr(0), func(sql string) bool { return strings.Contains(sql, "LIMIT 0") @@ -532,7 +557,7 @@ func TestLimitClause(t *testing.T) { for i, test := range tests { q := &Query{ - limit: test.limit, + limit: test.limit, dialect: &drivers.Dialect{LQ: '"', RQ: '"', UseIndexPlaceholders: true, UseTopClause: false}, } sql, _ := BuildQuery(q) diff --git a/queries/query_test.go b/queries/query_test.go index f813ea61e..9f6b79191 100644 --- a/queries/query_test.go +++ b/queries/query_test.go @@ -604,17 +604,17 @@ func TestAppendWith(t *testing.T) { t.Parallel() q := &Query{} - AppendWith(q, "cte_0 AS (SELECT * FROM table_0 WHERE thing=$1 AND stuff=$2)", 5, 10) - AppendWith(q, "cte_1 AS (SELECT * FROM table_1 WHERE thing=$1 AND stuff=$2)", 5, 10) + AppendWith(q, "cte_0", "SELECT * FROM table_0 WHERE thing=$1 AND stuff=$2", 5, 10) + AppendWith(q, "cte_1", "SELECT * FROM table_1 WHERE thing=$1 AND stuff=$2", 5, 10) if len(q.withs) != 2 { t.Errorf("Expected len 2, got %d", len(q.withs)) } - if q.withs[0].clause != "cte_0 AS (SELECT * FROM table_0 WHERE thing=$1 AND stuff=$2)" { + if q.withs[0].alias != "cte_0" || q.withs[0].clause != "SELECT * FROM table_0 WHERE thing=$1 AND stuff=$2" { t.Errorf("Got invalid with on string: %#v", q.withs) } - if q.withs[1].clause != "cte_1 AS (SELECT * FROM table_1 WHERE thing=$1 AND stuff=$2)" { + if q.withs[1].alias != "cte_1" || q.withs[1].clause != "SELECT * FROM table_1 WHERE thing=$1 AND stuff=$2" { t.Errorf("Got invalid with on string: %#v", q.withs) } @@ -629,8 +629,9 @@ func TestAppendWith(t *testing.T) { t.Errorf("Invalid args values, got %#v", q.withs[0].args) } - q.withs = []argClause{{ - clause: "other_cte AS (SELECT * FROM other_table WHERE thing=$1 AND stuff=$2)", + q.withs = []with{{ + alias: "other_cte", + clause: "SELECT * FROM other_table WHERE thing=$1 AND stuff=$2", args: []interface{}{3, 7}, }} @@ -638,7 +639,7 @@ func TestAppendWith(t *testing.T) { t.Errorf("Expected len 1, got %d", len(q.withs)) } - if q.withs[0].clause != "other_cte AS (SELECT * FROM other_table WHERE thing=$1 AND stuff=$2)" { + if q.withs[0].alias != "other_cte" || q.withs[0].clause != "SELECT * FROM other_table WHERE thing=$1 AND stuff=$2" { t.Errorf("Got invalid with on string: %#v", q.withs) } }