From 24a9651b3ed4fbd800f80755731939a5599a6019 Mon Sep 17 00:00:00 2001 From: Stephen Afam-Osemene Date: Tue, 2 Jan 2024 19:57:10 +0000 Subject: [PATCH] Change CONFLICT/DUPLICATE KEY UPDATE to use mods --- clause/conflict.go | 2 +- clause/set.go | 2 +- dialect/mysql/dialect/insert.go | 4 +-- dialect/mysql/dialect/update.go | 2 +- dialect/mysql/im/qm.go | 53 ++++++++++++++++++++------------ dialect/mysql/insert_test.go | 9 +++--- dialect/mysql/table.go | 2 +- dialect/psql/dialect/update.go | 2 +- dialect/psql/im/qm.go | 40 ++++++++++++++++++++++++ dialect/psql/insert_test.go | 17 +++++----- dialect/psql/table.go | 3 +- dialect/sqlite/dialect/update.go | 2 +- dialect/sqlite/im/qm.go | 40 ++++++++++++++++++++++++ dialect/sqlite/insert_test.go | 7 +++-- dialect/sqlite/table.go | 3 +- mods/conflict.go | 41 +++--------------------- 16 files changed, 146 insertions(+), 83 deletions(-) diff --git a/clause/conflict.go b/clause/conflict.go index ee4af302..234c8eec 100644 --- a/clause/conflict.go +++ b/clause/conflict.go @@ -28,7 +28,7 @@ func (c Conflict) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) w.Write([]byte(" DO ")) w.Write([]byte(c.Do)) - setArgs, err := bob.ExpressIf(w, d, start+len(args), c.Set, true, " ", "") + setArgs, err := bob.ExpressIf(w, d, start+len(args), c.Set, len(c.Set.Set) > 0, " SET\n", "") if err != nil { return nil, err } diff --git a/clause/set.go b/clause/set.go index 672510d9..66f61871 100644 --- a/clause/set.go +++ b/clause/set.go @@ -15,5 +15,5 @@ func (s *Set) AppendSet(exprs ...any) { } func (s Set) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - return bob.ExpressSlice(w, d, start, s.Set, "SET\n", ",\n", "") + return bob.ExpressSlice(w, d, start, s.Set, "", ",\n", "") } diff --git a/dialect/mysql/dialect/insert.go b/dialect/mysql/dialect/insert.go index f051e601..4521d3ef 100644 --- a/dialect/mysql/dialect/insert.go +++ b/dialect/mysql/dialect/insert.go @@ -21,7 +21,7 @@ type InsertQuery struct { RowAlias string ColumnAlias []string Sets []Set - DuplicateKeyUpdate []Set + DuplicateKeyUpdate clause.Set } func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { @@ -104,7 +104,7 @@ func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err } } - updateArgs, err := bob.ExpressSlice(w, d, start+len(args), i.DuplicateKeyUpdate, + updateArgs, err := bob.ExpressSlice(w, d, start+len(args), i.DuplicateKeyUpdate.Set, "\nON DUPLICATE KEY UPDATE\n", ",\n", "") if err != nil { return nil, err diff --git a/dialect/mysql/dialect/update.go b/dialect/mysql/dialect/update.go index 7a08420c..da982965 100644 --- a/dialect/mysql/dialect/update.go +++ b/dialect/mysql/dialect/update.go @@ -54,7 +54,7 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err } args = append(args, fromArgs...) - setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " ", "") + setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "") if err != nil { return nil, err } diff --git a/dialect/mysql/im/qm.go b/dialect/mysql/im/qm.go index b0d919f4..48e2cd79 100644 --- a/dialect/mysql/im/qm.go +++ b/dialect/mysql/im/qm.go @@ -2,8 +2,10 @@ package im import ( "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/clause" "github.com/stephenafamo/bob/dialect/mysql/dialect" "github.com/stephenafamo/bob/expr" + "github.com/stephenafamo/bob/internal" "github.com/stephenafamo/bob/mods" ) @@ -68,39 +70,52 @@ func As(rowAlias string, colAlias ...string) bob.Mod[*dialect.InsertQuery] { }) } -func OnDuplicateKeyUpdate() *dupKeyUpdater { - return &dupKeyUpdater{} -} +func OnDuplicateKeyUpdate(clauses ...bob.Mod[*clause.Set]) bob.Mod[*dialect.InsertQuery] { + sets := clause.Set{} + for _, m := range clauses { + m.Apply(&sets) + } -type dupKeyUpdater struct { - sets []dialect.Set + return mods.QueryModFunc[*dialect.InsertQuery](func(q *dialect.InsertQuery) { + q.DuplicateKeyUpdate.Set = append(q.DuplicateKeyUpdate.Set, sets.Set...) + }) } -func (s dupKeyUpdater) Apply(q *dialect.InsertQuery) { - q.DuplicateKeyUpdate = append(q.DuplicateKeyUpdate, s.sets...) +//======================================== +// For use in ON DUPLICATE KEY UPDATE +//======================================== + +func Update(exprs ...bob.Expression) bob.Mod[*clause.Set] { + return mods.QueryModFunc[*clause.Set](func(c *clause.Set) { + c.Set = append(c.Set, internal.ToAnySlice(exprs)...) + }) } -func (s *dupKeyUpdater) SetCol(col string, val any) *dupKeyUpdater { - s.sets = append(s.sets, dialect.Set{Col: col, Val: val}) - return s +func UpdateCol(col string) mods.Set[*clause.Set] { + return mods.Set[*clause.Set]{col} } -func (s *dupKeyUpdater) Set(alias string, cols ...string) *dupKeyUpdater { - newCols := make([]dialect.Set, len(cols)) +func UpdateWithAlias(alias string, cols ...string) bob.Mod[*clause.Set] { + newCols := make([]any, len(cols)) for i, c := range cols { newCols[i] = dialect.Set{Col: c, Val: expr.Quote(alias, c)} } - s.sets = append(s.sets, newCols...) - return s + return mods.QueryModFunc[*clause.Set](func(s *clause.Set) { + s.Set = append(s.Set, newCols...) + }) } -func (s *dupKeyUpdater) SetValues(cols ...string) *dupKeyUpdater { - newCols := make([]dialect.Set, len(cols)) +func UpdateWithValues(cols ...string) bob.Mod[*clause.Set] { + newCols := make([]any, len(cols)) for i, c := range cols { - newCols[i] = dialect.Set{Col: c, Val: dialect.NewFunction("VALUES", expr.Quote(c))} + newCols[i] = dialect.Set{ + Col: c, + Val: dialect.NewFunction("VALUES", expr.Quote(c)), + } } - s.sets = append(s.sets, newCols...) - return s + return mods.QueryModFunc[*clause.Set](func(s *clause.Set) { + s.Set = append(s.Set, newCols...) + }) } diff --git a/dialect/mysql/insert_test.go b/dialect/mysql/insert_test.go index 26aaa787..b0a673f1 100644 --- a/dialect/mysql/insert_test.go +++ b/dialect/mysql/insert_test.go @@ -74,12 +74,13 @@ func TestInsert(t *testing.T) { im.Values(mysql.Arg(8, "Anvil Distribution")), im.Values(mysql.Arg(9, "Sentry Distribution")), im.As("new"), - im.OnDuplicateKeyUpdate(). - Set("new", "did"). - SetCol("dbname", mysql.Concat( + im.OnDuplicateKeyUpdate( + im.UpdateWithAlias("new", "did"), + im.UpdateCol("dbname").To(mysql.Concat( mysql.Quote("new", "dname"), mysql.S(" (formerly "), mysql.Quote("d", "dname"), mysql.S(")"), )), + ), ), ExpectedSQL: `INSERT INTO distributors (` + "`did`" + `, ` + "`dname`" + `) VALUES (?, ?), (?, ?) @@ -94,7 +95,7 @@ func TestInsert(t *testing.T) { im.Into("distributors", "did", "dname"), im.Values(mysql.Arg(8, "Anvil Distribution")), im.Values(mysql.Arg(9, "Sentry Distribution")), - im.OnDuplicateKeyUpdate().SetValues("did", "dbname"), + im.OnDuplicateKeyUpdate(im.UpdateWithValues("did", "dbname")), ), ExpectedSQL: `INSERT INTO distributors (` + "`did`" + `, ` + "`dname`" + `) VALUES (?, ?), (?, ?) diff --git a/dialect/mysql/table.go b/dialect/mysql/table.go index 1428f137..f7053017 100644 --- a/dialect/mysql/table.go +++ b/dialect/mysql/table.go @@ -276,7 +276,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut updateCols = columns } - conflictQM = im.OnDuplicateKeyUpdate().SetValues(updateCols...) + conflictQM = im.OnDuplicateKeyUpdate(im.UpdateWithValues(updateCols...)) } q := Insert( diff --git a/dialect/psql/dialect/update.go b/dialect/psql/dialect/update.go index dc06c0ab..5b7f0bb1 100644 --- a/dialect/psql/dialect/update.go +++ b/dialect/psql/dialect/update.go @@ -41,7 +41,7 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err } args = append(args, tableArgs...) - setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " ", "") + setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "") if err != nil { return nil, err } diff --git a/dialect/psql/im/qm.go b/dialect/psql/im/qm.go index 2835c523..6f1a5d61 100644 --- a/dialect/psql/im/qm.go +++ b/dialect/psql/im/qm.go @@ -4,6 +4,8 @@ import ( "github.com/stephenafamo/bob" "github.com/stephenafamo/bob/clause" "github.com/stephenafamo/bob/dialect/psql/dialect" + "github.com/stephenafamo/bob/expr" + "github.com/stephenafamo/bob/internal" "github.com/stephenafamo/bob/mods" ) @@ -85,3 +87,41 @@ func OnConflictOnConstraint(constraint string) mods.Conflict[*dialect.InsertQuer func Returning(clauses ...any) bob.Mod[*dialect.InsertQuery] { return mods.Returning[*dialect.InsertQuery](clauses) } + +//======================================== +// For use in ON CONFLICT DO UPDATE SET +//======================================== + +func Set(sets ...bob.Expression) bob.Mod[*clause.Conflict] { + return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) { + c.Set.Set = append(c.Set.Set, internal.ToAnySlice(sets)...) + }) +} + +func SetCol(from string) mods.Set[*clause.Conflict] { + return mods.Set[*clause.Conflict]{from} +} + +func SetExcluded(cols ...string) bob.Mod[*clause.Conflict] { + exprs := make([]any, 0, len(cols)) + for _, col := range cols { + if col == "" { + continue + } + exprs = append(exprs, + expr.Join{Exprs: []bob.Expression{ + expr.Quote(col), expr.Raw("= EXCLUDED."), expr.Quote(col), + }}, + ) + } + + return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) { + c.Set.Set = append(c.Set.Set, exprs...) + }) +} + +func Where(e bob.Expression) bob.Mod[*clause.Conflict] { + return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) { + c.Where.Conditions = append(c.Where.Conditions, e) + }) +} diff --git a/dialect/psql/insert_test.go b/dialect/psql/insert_test.go index c51cebdb..167862f5 100644 --- a/dialect/psql/insert_test.go +++ b/dialect/psql/insert_test.go @@ -49,12 +49,13 @@ func TestInsert(t *testing.T) { im.IntoAs("distributors", "d", "did", "dname"), im.Values(psql.Arg(8, "Anvil Distribution")), im.Values(psql.Arg(9, "Sentry Distribution")), - im.OnConflict("did").DoUpdate(). - Set("dname", psql.Concat( + im.OnConflict("did").DoUpdate( + im.SetCol("dname").To(psql.Concat( psql.Raw("EXCLUDED.dname"), psql.S(" (formerly "), psql.Quote("d", "dname"), psql.S(")"), - )). - Where(psql.Quote("d", "zipcode").NE(psql.S("21201"))), + )), + im.Where(psql.Quote("d", "zipcode").NE(psql.S("21201"))), + ), ), ExpectedSQL: `INSERT INTO distributors AS "d" ("did", "dname") VALUES ($1, $2), ($3, $4) @@ -68,10 +69,10 @@ func TestInsert(t *testing.T) { im.IntoAs("distributors", "d", "did", "dname"), im.Values(psql.Arg(8, "Anvil Distribution")), im.Values(psql.Arg(9, "Sentry Distribution")), - im.OnConflictOnConstraint("distributors_pkey"). - DoUpdate(). - SetExcluded("dname"). - Where(psql.Quote("d", "zipcode").NE(psql.S("21201"))), + im.OnConflictOnConstraint("distributors_pkey").DoUpdate( + im.SetExcluded("dname"), + im.Where(psql.Quote("d", "zipcode").NE(psql.S("21201"))), + ), ), ExpectedSQL: `INSERT INTO distributors AS "d" ("did", "dname") VALUES ($1, $2), ($3, $4) diff --git a/dialect/psql/table.go b/dialect/psql/table.go index f0e3dca0..2b66006f 100644 --- a/dialect/psql/table.go +++ b/dialect/psql/table.go @@ -212,8 +212,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut excludeSetCols = t.setMapping.NonPKs } conflictQM = im.OnConflict(internal.ToAnySlice(conflictCols)...). - DoUpdate(). - SetExcluded(excludeSetCols...) + DoUpdate(im.SetExcluded(excludeSetCols...)) } q := Insert( diff --git a/dialect/sqlite/dialect/update.go b/dialect/sqlite/dialect/update.go index 08e26e80..ea51289f 100644 --- a/dialect/sqlite/dialect/update.go +++ b/dialect/sqlite/dialect/update.go @@ -42,7 +42,7 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err } args = append(args, tableArgs...) - setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " ", "") + setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "") if err != nil { return nil, err } diff --git a/dialect/sqlite/im/qm.go b/dialect/sqlite/im/qm.go index 1a167190..c7d982a5 100644 --- a/dialect/sqlite/im/qm.go +++ b/dialect/sqlite/im/qm.go @@ -4,6 +4,8 @@ import ( "github.com/stephenafamo/bob" "github.com/stephenafamo/bob/clause" "github.com/stephenafamo/bob/dialect/sqlite/dialect" + "github.com/stephenafamo/bob/expr" + "github.com/stephenafamo/bob/internal" "github.com/stephenafamo/bob/mods" ) @@ -82,3 +84,41 @@ func OnConflict(columns ...any) mods.Conflict[*dialect.InsertQuery] { func Returning(clauses ...any) bob.Mod[*dialect.InsertQuery] { return mods.Returning[*dialect.InsertQuery](clauses) } + +//======================================== +// For use in ON CONFLICT DO UPDATE SET +//======================================== + +func Set(sets ...bob.Expression) bob.Mod[*clause.Conflict] { + return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) { + c.Set.Set = append(c.Set.Set, internal.ToAnySlice(sets)...) + }) +} + +func SetCol(from string) mods.Set[*clause.Conflict] { + return mods.Set[*clause.Conflict]{from} +} + +func SetExcluded(cols ...string) bob.Mod[*clause.Conflict] { + exprs := make([]any, 0, len(cols)) + for _, col := range cols { + if col == "" { + continue + } + exprs = append(exprs, + expr.Join{Exprs: []bob.Expression{ + expr.Quote(col), expr.Raw("= EXCLUDED."), expr.Quote(col), + }}, + ) + } + + return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) { + c.Set.Set = append(c.Set.Set, exprs...) + }) +} + +func Where(e bob.Expression) bob.Mod[*clause.Conflict] { + return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) { + c.Where.Conditions = append(c.Where.Conditions, e) + }) +} diff --git a/dialect/sqlite/insert_test.go b/dialect/sqlite/insert_test.go index 38348091..a0108ca2 100644 --- a/dialect/sqlite/insert_test.go +++ b/dialect/sqlite/insert_test.go @@ -58,9 +58,10 @@ func TestInsert(t *testing.T) { im.IntoAs("distributors", "d", "did", "dname"), im.Values(sqlite.Arg(8, "Anvil Distribution")), im.Values(sqlite.Arg(9, "Sentry Distribution")), - im.OnConflict("did").DoUpdate(). - SetExcluded("dname"). - Where(sqlite.Quote("d", "zipcode").NE(sqlite.S("21201"))), + im.OnConflict("did").DoUpdate( + im.SetExcluded("dname"), + im.Where(sqlite.Quote("d", "zipcode").NE(sqlite.S("21201"))), + ), ), ExpectedSQL: `INSERT INTO distributors AS "d" ("did", "dname") VALUES (?1, ?2), (?3, ?4) diff --git a/dialect/sqlite/table.go b/dialect/sqlite/table.go index 99de0244..e5ffe973 100644 --- a/dialect/sqlite/table.go +++ b/dialect/sqlite/table.go @@ -211,8 +211,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut excludeSetCols = t.setMapping.NonPKs } conflictQM = im.OnConflict(internal.ToAnySlice(conflictCols)...). - DoUpdate(). - SetExcluded(excludeSetCols...) + DoUpdate(im.SetExcluded(excludeSetCols...)) } q := Insert( diff --git a/mods/conflict.go b/mods/conflict.go index e9847645..a2a9dc04 100644 --- a/mods/conflict.go +++ b/mods/conflict.go @@ -3,7 +3,6 @@ package mods import ( "github.com/stephenafamo/bob" "github.com/stephenafamo/bob/clause" - "github.com/stephenafamo/bob/expr" ) type Conflict[Q interface{ SetConflict(clause.Conflict) }] func() clause.Conflict @@ -12,7 +11,7 @@ func (s Conflict[Q]) Apply(q Q) { q.SetConflict(s()) } -func (c Conflict[Q]) OnWhere(where ...any) Conflict[Q] { +func (c Conflict[Q]) Where(where ...any) Conflict[Q] { conflict := c() conflict.Target.Where = append(conflict.Target.Where, where...) @@ -30,45 +29,13 @@ func (c Conflict[Q]) DoNothing() bob.Mod[Q] { }) } -func (c Conflict[Q]) DoUpdate() Conflict[Q] { +func (c Conflict[Q]) DoUpdate(sets ...bob.Mod[*clause.Conflict]) bob.Mod[Q] { conflict := c() conflict.Do = "UPDATE" - return Conflict[Q](func() clause.Conflict { - return conflict - }) -} - -func (c Conflict[Q]) Set(a, b any) Conflict[Q] { - conflict := c() - conflict.Set.Set = append(conflict.Set.Set, expr.OP("=", a, b)) - - return Conflict[Q](func() clause.Conflict { - return conflict - }) -} - -func (c Conflict[Q]) SetExcluded(cols ...string) Conflict[Q] { - conflict := c() - exprs := make([]any, 0, len(cols)) - for _, col := range cols { - if col == "" { - continue - } - exprs = append(exprs, - expr.Join{Exprs: []bob.Expression{expr.Quote(col), expr.Raw("= EXCLUDED."), expr.Quote(col)}}, - ) + for _, set := range sets { + set.Apply(&conflict) } - conflict.Set.Set = append(conflict.Set.Set, exprs...) - - return Conflict[Q](func() clause.Conflict { - return conflict - }) -} - -func (c Conflict[Q]) Where(where ...any) Conflict[Q] { - conflict := c() - conflict.Where.Conditions = append(conflict.Where.Conditions, where...) return Conflict[Q](func() clause.Conflict { return conflict