diff --git a/CHANGELOG.md b/CHANGELOG.md index ceb8d399..53a03ff4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed unnecessary import of `strings` in `bobfactory_random.go`. - Fixed data races in unit tests. +- Fixed invalid SQL statements generated by `sm.OrderBy().Collate()`. ## [v0.28.1] - 2024-06-28 diff --git a/clause/order_by.go b/clause/order_by.go index 64c84ec4..dee34ae4 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -24,23 +24,25 @@ func (o OrderBy) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) } type OrderDef struct { - Expression any - Direction string // ASC | DESC | USING operator - Nulls string // FIRST | LAST - CollationName string + Expression any + Direction string // ASC | DESC | USING operator + Nulls string // FIRST | LAST + Collation bob.Expression } func (o OrderDef) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) { - if o.CollationName != "" { - w.Write([]byte("COLLATE ")) - w.Write([]byte(o.CollationName)) - } - args, err := bob.Express(w, d, start, o.Expression) if err != nil { return nil, err } + if o.Collation != nil { + _, err = o.Collation.WriteSQL(w, d, start) + if err != nil { + return nil, err + } + } + if o.Direction != "" { w.Write([]byte(" ")) w.Write([]byte(o.Direction)) diff --git a/dialect/mysql/dialect/mods.go b/dialect/mysql/dialect/mods.go index 17b7919f..5a63b4d2 100644 --- a/dialect/mysql/dialect/mods.go +++ b/dialect/mysql/dialect/mods.go @@ -1,6 +1,7 @@ package dialect import ( + "fmt" "io" "github.com/stephenafamo/bob" @@ -259,6 +260,17 @@ func (j JoinChain[Q]) Using(using ...string) bob.Mod[Q] { return mods.Join[Q](jo) } +type collation struct { + name string +} + +func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) { + if _, err := fmt.Fprintf(w, " COLLATE %s", c.name); err != nil { + return nil, err + } + return nil, nil +} + type OrderBy[Q interface{ AppendOrder(clause.OrderDef) }] func() clause.OrderDef func (s OrderBy[Q]) Apply(q Q) { @@ -283,9 +295,9 @@ func (o OrderBy[Q]) Desc() OrderBy[Q] { }) } -func (o OrderBy[Q]) Collate(collation string) OrderBy[Q] { +func (o OrderBy[Q]) Collate(collationName string) OrderBy[Q] { order := o() - order.CollationName = collation + order.Collation = collation{name: collationName} return OrderBy[Q](func() clause.OrderDef { return order diff --git a/dialect/mysql/select_test.go b/dialect/mysql/select_test.go index b13dc126..d4461bda 100644 --- a/dialect/mysql/select_test.go +++ b/dialect/mysql/select_test.go @@ -68,6 +68,14 @@ func TestSelect(t *testing.T) { ExpectedSQL: "SELECT id, name FROM users WHERE ((`id`, `employee_id`) IN ((?, ?), (?, ?)))", ExpectedArgs: []any{100, 200, 300, 400}, }, + "select with order by and collate": { + Query: mysql.Select( + sm.Columns("id", "name"), + sm.From("users"), + sm.OrderBy("name").Collate("utf8mb4_bg_0900_as_cs").Asc(), + ), + ExpectedSQL: "SELECT id, name FROM users ORDER BY name COLLATE utf8mb4_bg_0900_as_cs ASC", + }, } testutils.RunTests(t, examples, formatter) diff --git a/dialect/psql/dialect/mods.go b/dialect/psql/dialect/mods.go index 2a43b3cd..99cde8c7 100644 --- a/dialect/psql/dialect/mods.go +++ b/dialect/psql/dialect/mods.go @@ -197,6 +197,18 @@ func (j JoinChain[Q]) Using(using ...string) bob.Mod[Q] { return mods.Join[Q](jo) } +type collation struct { + name string +} + +func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) { + if _, err := w.Write([]byte(" COLLATE ")); err != nil { + return nil, err + } + d.WriteQuoted(w, c.name) + return nil, nil +} + type OrderBy[Q interface{ AppendOrder(clause.OrderDef) }] func() clause.OrderDef func (s OrderBy[Q]) Apply(q Q) { @@ -248,9 +260,9 @@ func (o OrderBy[Q]) NullsLast() OrderBy[Q] { }) } -func (o OrderBy[Q]) Collate(collation string) OrderBy[Q] { +func (o OrderBy[Q]) Collate(collationName string) OrderBy[Q] { order := o() - order.CollationName = collation + order.Collation = collation{name: collationName} return OrderBy[Q](func() clause.OrderDef { return order diff --git a/dialect/psql/select_test.go b/dialect/psql/select_test.go index 7db4ff2d..f9ae7738 100644 --- a/dialect/psql/select_test.go +++ b/dialect/psql/select_test.go @@ -162,6 +162,14 @@ WINDOW w AS (PARTITION BY depname ORDER BY salary)`, sm.Window("w").PartitionBy("depname").OrderBy("salary"), ), }, + "select with order by and collate": { + Query: psql.Select( + sm.Columns("id", "name"), + sm.From("users"), + sm.OrderBy("name").Collate("bg-BG-x-icu").Asc(), + ), + ExpectedSQL: `SELECT id, name FROM users ORDER BY name COLLATE "bg-BG-x-icu" ASC`, + }, } testutils.RunTests(t, examples, formatter) diff --git a/dialect/sqlite/dialect/mods.go b/dialect/sqlite/dialect/mods.go index 141e4f3b..09ca12b2 100644 --- a/dialect/sqlite/dialect/mods.go +++ b/dialect/sqlite/dialect/mods.go @@ -1,6 +1,7 @@ package dialect import ( + "fmt" "io" "github.com/stephenafamo/bob" @@ -228,15 +229,26 @@ func CrossJoin[Q Joinable](e any) bob.Mod[Q] { return Join[Q](clause.CrossJoin, e) } +type collation struct { + name string +} + +func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) { + if _, err := fmt.Fprintf(w, " COLLATE %s", c.name); err != nil { + return nil, err + } + return nil, nil +} + type OrderBy[Q interface{ AppendOrder(clause.OrderDef) }] func() clause.OrderDef func (s OrderBy[Q]) Apply(q Q) { q.AppendOrder(s()) } -func (o OrderBy[Q]) Collate(collation string) OrderBy[Q] { +func (o OrderBy[Q]) Collate(collationName string) OrderBy[Q] { order := o() - order.CollationName = collation + order.Collation = &collation{name: collationName} return OrderBy[Q](func() clause.OrderDef { return order diff --git a/dialect/sqlite/select_test.go b/dialect/sqlite/select_test.go index e7e86259..e85feeb1 100644 --- a/dialect/sqlite/select_test.go +++ b/dialect/sqlite/select_test.go @@ -75,6 +75,14 @@ func TestSelect(t *testing.T) { ExpectedSQL: `SELECT id, name FROM users WHERE (("id", "employee_id") IN ((?1, ?2), (?3, ?4)))`, ExpectedArgs: []any{100, 200, 300, 400}, }, + "select with order by and collate": { + Query: sqlite.Select( + sm.Columns("id", "name"), + sm.From("users"), + sm.OrderBy("name").Collate("NOCASE").Asc(), + ), + ExpectedSQL: `SELECT id, name FROM users ORDER BY name COLLATE NOCASE ASC`, + }, } testutils.RunTests(t, examples, formatter) diff --git a/gen/output_test.go b/gen/output_test.go index bda0371a..ad3b61a7 100644 --- a/gen/output_test.go +++ b/gen/output_test.go @@ -158,12 +158,12 @@ func TestGetOutputFilename(t *testing.T) { t.Run(name, func(t *testing.T) { notTest := getOutputFilename(tc.SchemaName, tc.TableName, false, tc.IsGo) if diff := cmp.Diff(tc.Expected, notTest); diff != "" { - t.Fatalf(diff) + t.Fatal(diff) } isTest := getOutputFilename(tc.SchemaName, tc.TableName, true, tc.IsGo) if diff := cmp.Diff(tc.Expected+"_test", isTest); diff != "" { - t.Fatalf(diff) + t.Fatal(diff) } }) } diff --git a/website/examples_gen.go b/website/examples_gen.go index 5e19bd30..06570355 100644 --- a/website/examples_gen.go +++ b/website/examples_gen.go @@ -348,7 +348,7 @@ func toMarkdown(destination string, cases []testcase) { } if c.doc == "" { - c.doc = strings.Title(c.name) //nolint:staticcheck + c.doc = strings.Title(c.name) } // write the sql query fmt.Fprintf(buf, "## %s\n\nSQL:\n\n```sql\n%s\n```\n\n", c.doc, c.query) @@ -380,7 +380,7 @@ func toMarkdown(destination string, cases []testcase) { func markdownTitle(s string) string { base := filepath.Base(s) heading := strings.TrimSuffix(base, filepath.Ext(base)) - return fmt.Sprintf("# %s\n\n", strings.Title(heading)) //nolint:staticcheck + return fmt.Sprintf("# %s\n\n", strings.Title(heading)) } func reindent(s string) string {