Skip to content

Commit

Permalink
Change CONFLICT/DUPLICATE KEY UPDATE to use mods
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenafamo committed Jan 2, 2024
1 parent 8c24a51 commit 24a9651
Show file tree
Hide file tree
Showing 16 changed files with 146 additions and 83 deletions.
2 changes: 1 addition & 1 deletion clause/conflict.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (c Conflict) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error)
w.Write([]byte(" DO "))
w.Write([]byte(c.Do))

setArgs, err := bob.ExpressIf(w, d, start+len(args), c.Set, true, " ", "")
setArgs, err := bob.ExpressIf(w, d, start+len(args), c.Set, len(c.Set.Set) > 0, " SET\n", "")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion clause/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ func (s *Set) AppendSet(exprs ...any) {
}

func (s Set) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
return bob.ExpressSlice(w, d, start, s.Set, "SET\n", ",\n", "")
return bob.ExpressSlice(w, d, start, s.Set, "", ",\n", "")
}
4 changes: 2 additions & 2 deletions dialect/mysql/dialect/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type InsertQuery struct {
RowAlias string
ColumnAlias []string
Sets []Set
DuplicateKeyUpdate []Set
DuplicateKeyUpdate clause.Set
}

func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
Expand Down Expand Up @@ -104,7 +104,7 @@ func (i InsertQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err
}
}

updateArgs, err := bob.ExpressSlice(w, d, start+len(args), i.DuplicateKeyUpdate,
updateArgs, err := bob.ExpressSlice(w, d, start+len(args), i.DuplicateKeyUpdate.Set,
"\nON DUPLICATE KEY UPDATE\n", ",\n", "")
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion dialect/mysql/dialect/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err
}
args = append(args, fromArgs...)

setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " ", "")
setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "")
if err != nil {
return nil, err
}
Expand Down
53 changes: 34 additions & 19 deletions dialect/mysql/im/qm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package im

import (
"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/clause"
"github.com/stephenafamo/bob/dialect/mysql/dialect"
"github.com/stephenafamo/bob/expr"
"github.com/stephenafamo/bob/internal"
"github.com/stephenafamo/bob/mods"
)

Expand Down Expand Up @@ -68,39 +70,52 @@ func As(rowAlias string, colAlias ...string) bob.Mod[*dialect.InsertQuery] {
})
}

func OnDuplicateKeyUpdate() *dupKeyUpdater {
return &dupKeyUpdater{}
}
func OnDuplicateKeyUpdate(clauses ...bob.Mod[*clause.Set]) bob.Mod[*dialect.InsertQuery] {
sets := clause.Set{}
for _, m := range clauses {
m.Apply(&sets)
}

type dupKeyUpdater struct {
sets []dialect.Set
return mods.QueryModFunc[*dialect.InsertQuery](func(q *dialect.InsertQuery) {
q.DuplicateKeyUpdate.Set = append(q.DuplicateKeyUpdate.Set, sets.Set...)
})
}

func (s dupKeyUpdater) Apply(q *dialect.InsertQuery) {
q.DuplicateKeyUpdate = append(q.DuplicateKeyUpdate, s.sets...)
//========================================
// For use in ON DUPLICATE KEY UPDATE
//========================================

func Update(exprs ...bob.Expression) bob.Mod[*clause.Set] {
return mods.QueryModFunc[*clause.Set](func(c *clause.Set) {
c.Set = append(c.Set, internal.ToAnySlice(exprs)...)
})
}

func (s *dupKeyUpdater) SetCol(col string, val any) *dupKeyUpdater {
s.sets = append(s.sets, dialect.Set{Col: col, Val: val})
return s
func UpdateCol(col string) mods.Set[*clause.Set] {
return mods.Set[*clause.Set]{col}
}

func (s *dupKeyUpdater) Set(alias string, cols ...string) *dupKeyUpdater {
newCols := make([]dialect.Set, len(cols))
func UpdateWithAlias(alias string, cols ...string) bob.Mod[*clause.Set] {
newCols := make([]any, len(cols))
for i, c := range cols {
newCols[i] = dialect.Set{Col: c, Val: expr.Quote(alias, c)}
}

s.sets = append(s.sets, newCols...)
return s
return mods.QueryModFunc[*clause.Set](func(s *clause.Set) {
s.Set = append(s.Set, newCols...)
})
}

func (s *dupKeyUpdater) SetValues(cols ...string) *dupKeyUpdater {
newCols := make([]dialect.Set, len(cols))
func UpdateWithValues(cols ...string) bob.Mod[*clause.Set] {
newCols := make([]any, len(cols))
for i, c := range cols {
newCols[i] = dialect.Set{Col: c, Val: dialect.NewFunction("VALUES", expr.Quote(c))}
newCols[i] = dialect.Set{
Col: c,
Val: dialect.NewFunction("VALUES", expr.Quote(c)),
}
}

s.sets = append(s.sets, newCols...)
return s
return mods.QueryModFunc[*clause.Set](func(s *clause.Set) {
s.Set = append(s.Set, newCols...)
})
}
9 changes: 5 additions & 4 deletions dialect/mysql/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ func TestInsert(t *testing.T) {
im.Values(mysql.Arg(8, "Anvil Distribution")),
im.Values(mysql.Arg(9, "Sentry Distribution")),
im.As("new"),
im.OnDuplicateKeyUpdate().
Set("new", "did").
SetCol("dbname", mysql.Concat(
im.OnDuplicateKeyUpdate(
im.UpdateWithAlias("new", "did"),
im.UpdateCol("dbname").To(mysql.Concat(
mysql.Quote("new", "dname"), mysql.S(" (formerly "),
mysql.Quote("d", "dname"), mysql.S(")"),
)),
),
),
ExpectedSQL: `INSERT INTO distributors (` + "`did`" + `, ` + "`dname`" + `)
VALUES (?, ?), (?, ?)
Expand All @@ -94,7 +95,7 @@ func TestInsert(t *testing.T) {
im.Into("distributors", "did", "dname"),
im.Values(mysql.Arg(8, "Anvil Distribution")),
im.Values(mysql.Arg(9, "Sentry Distribution")),
im.OnDuplicateKeyUpdate().SetValues("did", "dbname"),
im.OnDuplicateKeyUpdate(im.UpdateWithValues("did", "dbname")),
),
ExpectedSQL: `INSERT INTO distributors (` + "`did`" + `, ` + "`dname`" + `)
VALUES (?, ?), (?, ?)
Expand Down
2 changes: 1 addition & 1 deletion dialect/mysql/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut
updateCols = columns
}

conflictQM = im.OnDuplicateKeyUpdate().SetValues(updateCols...)
conflictQM = im.OnDuplicateKeyUpdate(im.UpdateWithValues(updateCols...))
}

q := Insert(
Expand Down
2 changes: 1 addition & 1 deletion dialect/psql/dialect/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err
}
args = append(args, tableArgs...)

setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " ", "")
setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "")
if err != nil {
return nil, err
}
Expand Down
40 changes: 40 additions & 0 deletions dialect/psql/im/qm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/clause"
"github.com/stephenafamo/bob/dialect/psql/dialect"
"github.com/stephenafamo/bob/expr"
"github.com/stephenafamo/bob/internal"
"github.com/stephenafamo/bob/mods"
)

Expand Down Expand Up @@ -85,3 +87,41 @@ func OnConflictOnConstraint(constraint string) mods.Conflict[*dialect.InsertQuer
func Returning(clauses ...any) bob.Mod[*dialect.InsertQuery] {
return mods.Returning[*dialect.InsertQuery](clauses)
}

//========================================
// For use in ON CONFLICT DO UPDATE SET
//========================================

func Set(sets ...bob.Expression) bob.Mod[*clause.Conflict] {
return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) {
c.Set.Set = append(c.Set.Set, internal.ToAnySlice(sets)...)
})
}

func SetCol(from string) mods.Set[*clause.Conflict] {
return mods.Set[*clause.Conflict]{from}
}

func SetExcluded(cols ...string) bob.Mod[*clause.Conflict] {
exprs := make([]any, 0, len(cols))
for _, col := range cols {
if col == "" {
continue
}
exprs = append(exprs,
expr.Join{Exprs: []bob.Expression{
expr.Quote(col), expr.Raw("= EXCLUDED."), expr.Quote(col),
}},
)
}

return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) {
c.Set.Set = append(c.Set.Set, exprs...)
})
}

func Where(e bob.Expression) bob.Mod[*clause.Conflict] {
return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) {
c.Where.Conditions = append(c.Where.Conditions, e)
})
}
17 changes: 9 additions & 8 deletions dialect/psql/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,13 @@ func TestInsert(t *testing.T) {
im.IntoAs("distributors", "d", "did", "dname"),
im.Values(psql.Arg(8, "Anvil Distribution")),
im.Values(psql.Arg(9, "Sentry Distribution")),
im.OnConflict("did").DoUpdate().
Set("dname", psql.Concat(
im.OnConflict("did").DoUpdate(
im.SetCol("dname").To(psql.Concat(
psql.Raw("EXCLUDED.dname"), psql.S(" (formerly "),
psql.Quote("d", "dname"), psql.S(")"),
)).
Where(psql.Quote("d", "zipcode").NE(psql.S("21201"))),
)),
im.Where(psql.Quote("d", "zipcode").NE(psql.S("21201"))),
),
),
ExpectedSQL: `INSERT INTO distributors AS "d" ("did", "dname")
VALUES ($1, $2), ($3, $4)
Expand All @@ -68,10 +69,10 @@ func TestInsert(t *testing.T) {
im.IntoAs("distributors", "d", "did", "dname"),
im.Values(psql.Arg(8, "Anvil Distribution")),
im.Values(psql.Arg(9, "Sentry Distribution")),
im.OnConflictOnConstraint("distributors_pkey").
DoUpdate().
SetExcluded("dname").
Where(psql.Quote("d", "zipcode").NE(psql.S("21201"))),
im.OnConflictOnConstraint("distributors_pkey").DoUpdate(
im.SetExcluded("dname"),
im.Where(psql.Quote("d", "zipcode").NE(psql.S("21201"))),
),
),
ExpectedSQL: `INSERT INTO distributors AS "d" ("did", "dname")
VALUES ($1, $2), ($3, $4)
Expand Down
3 changes: 1 addition & 2 deletions dialect/psql/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut
excludeSetCols = t.setMapping.NonPKs
}
conflictQM = im.OnConflict(internal.ToAnySlice(conflictCols)...).
DoUpdate().
SetExcluded(excludeSetCols...)
DoUpdate(im.SetExcluded(excludeSetCols...))
}

q := Insert(
Expand Down
2 changes: 1 addition & 1 deletion dialect/sqlite/dialect/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (u UpdateQuery) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, err
}
args = append(args, tableArgs...)

setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " ", "")
setArgs, err := bob.ExpressIf(w, d, start+len(args), u.Set, true, " SET\n", "")
if err != nil {
return nil, err
}
Expand Down
40 changes: 40 additions & 0 deletions dialect/sqlite/im/qm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/clause"
"github.com/stephenafamo/bob/dialect/sqlite/dialect"
"github.com/stephenafamo/bob/expr"
"github.com/stephenafamo/bob/internal"
"github.com/stephenafamo/bob/mods"
)

Expand Down Expand Up @@ -82,3 +84,41 @@ func OnConflict(columns ...any) mods.Conflict[*dialect.InsertQuery] {
func Returning(clauses ...any) bob.Mod[*dialect.InsertQuery] {
return mods.Returning[*dialect.InsertQuery](clauses)
}

//========================================
// For use in ON CONFLICT DO UPDATE SET
//========================================

func Set(sets ...bob.Expression) bob.Mod[*clause.Conflict] {
return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) {
c.Set.Set = append(c.Set.Set, internal.ToAnySlice(sets)...)
})
}

func SetCol(from string) mods.Set[*clause.Conflict] {
return mods.Set[*clause.Conflict]{from}
}

func SetExcluded(cols ...string) bob.Mod[*clause.Conflict] {
exprs := make([]any, 0, len(cols))
for _, col := range cols {
if col == "" {
continue
}
exprs = append(exprs,
expr.Join{Exprs: []bob.Expression{
expr.Quote(col), expr.Raw("= EXCLUDED."), expr.Quote(col),
}},
)
}

return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) {
c.Set.Set = append(c.Set.Set, exprs...)
})
}

func Where(e bob.Expression) bob.Mod[*clause.Conflict] {
return mods.QueryModFunc[*clause.Conflict](func(c *clause.Conflict) {
c.Where.Conditions = append(c.Where.Conditions, e)
})
}
7 changes: 4 additions & 3 deletions dialect/sqlite/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ func TestInsert(t *testing.T) {
im.IntoAs("distributors", "d", "did", "dname"),
im.Values(sqlite.Arg(8, "Anvil Distribution")),
im.Values(sqlite.Arg(9, "Sentry Distribution")),
im.OnConflict("did").DoUpdate().
SetExcluded("dname").
Where(sqlite.Quote("d", "zipcode").NE(sqlite.S("21201"))),
im.OnConflict("did").DoUpdate(
im.SetExcluded("dname"),
im.Where(sqlite.Quote("d", "zipcode").NE(sqlite.S("21201"))),
),
),
ExpectedSQL: `INSERT INTO distributors AS "d" ("did", "dname")
VALUES (?1, ?2), (?3, ?4)
Expand Down
3 changes: 1 addition & 2 deletions dialect/sqlite/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ func (t *Table[T, Tslice, Tset]) UpsertMany(ctx context.Context, exec bob.Execut
excludeSetCols = t.setMapping.NonPKs
}
conflictQM = im.OnConflict(internal.ToAnySlice(conflictCols)...).
DoUpdate().
SetExcluded(excludeSetCols...)
DoUpdate(im.SetExcluded(excludeSetCols...))
}

q := Insert(
Expand Down
Loading

0 comments on commit 24a9651

Please sign in to comment.