Skip to content

Commit

Permalink
Merge pull request #268 from mbezhanov/collate-fix
Browse files Browse the repository at this point in the history
Fix invalid SQL generated by OrderBy().Collate()
  • Loading branch information
stephenafamo authored Aug 19, 2024
2 parents b1a4b01 + e62bdec commit 22be31a
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed unnecessary import of `strings` in `bobfactory_random.go`.
- Fixed data races in unit tests.
- Fixed invalid SQL statements generated by `sm.OrderBy().Collate()`.

## [v0.28.1] - 2024-06-28

Expand Down
20 changes: 11 additions & 9 deletions clause/order_by.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,25 @@ func (o OrderBy) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error)
}

type OrderDef struct {
Expression any
Direction string // ASC | DESC | USING operator
Nulls string // FIRST | LAST
CollationName string
Expression any
Direction string // ASC | DESC | USING operator
Nulls string // FIRST | LAST
Collation bob.Expression
}

func (o OrderDef) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
if o.CollationName != "" {
w.Write([]byte("COLLATE "))
w.Write([]byte(o.CollationName))
}

args, err := bob.Express(w, d, start, o.Expression)
if err != nil {
return nil, err
}

if o.Collation != nil {
_, err = o.Collation.WriteSQL(w, d, start)
if err != nil {
return nil, err
}
}

if o.Direction != "" {
w.Write([]byte(" "))
w.Write([]byte(o.Direction))
Expand Down
16 changes: 14 additions & 2 deletions dialect/mysql/dialect/mods.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dialect

import (
"fmt"
"io"

"github.com/stephenafamo/bob"
Expand Down Expand Up @@ -259,6 +260,17 @@ func (j JoinChain[Q]) Using(using ...string) bob.Mod[Q] {
return mods.Join[Q](jo)
}

type collation struct {
name string
}

func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) {
if _, err := fmt.Fprintf(w, " COLLATE %s", c.name); err != nil {
return nil, err
}
return nil, nil
}

type OrderBy[Q interface{ AppendOrder(clause.OrderDef) }] func() clause.OrderDef

func (s OrderBy[Q]) Apply(q Q) {
Expand All @@ -283,9 +295,9 @@ func (o OrderBy[Q]) Desc() OrderBy[Q] {
})
}

func (o OrderBy[Q]) Collate(collation string) OrderBy[Q] {
func (o OrderBy[Q]) Collate(collationName string) OrderBy[Q] {
order := o()
order.CollationName = collation
order.Collation = collation{name: collationName}

return OrderBy[Q](func() clause.OrderDef {
return order
Expand Down
8 changes: 8 additions & 0 deletions dialect/mysql/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ func TestSelect(t *testing.T) {
ExpectedSQL: "SELECT id, name FROM users WHERE ((`id`, `employee_id`) IN ((?, ?), (?, ?)))",
ExpectedArgs: []any{100, 200, 300, 400},
},
"select with order by and collate": {
Query: mysql.Select(
sm.Columns("id", "name"),
sm.From("users"),
sm.OrderBy("name").Collate("utf8mb4_bg_0900_as_cs").Asc(),
),
ExpectedSQL: "SELECT id, name FROM users ORDER BY name COLLATE utf8mb4_bg_0900_as_cs ASC",
},
}

testutils.RunTests(t, examples, formatter)
Expand Down
16 changes: 14 additions & 2 deletions dialect/psql/dialect/mods.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,18 @@ func (j JoinChain[Q]) Using(using ...string) bob.Mod[Q] {
return mods.Join[Q](jo)
}

type collation struct {
name string
}

func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) {
if _, err := w.Write([]byte(" COLLATE ")); err != nil {
return nil, err
}
d.WriteQuoted(w, c.name)
return nil, nil
}

type OrderBy[Q interface{ AppendOrder(clause.OrderDef) }] func() clause.OrderDef

func (s OrderBy[Q]) Apply(q Q) {
Expand Down Expand Up @@ -248,9 +260,9 @@ func (o OrderBy[Q]) NullsLast() OrderBy[Q] {
})
}

func (o OrderBy[Q]) Collate(collation string) OrderBy[Q] {
func (o OrderBy[Q]) Collate(collationName string) OrderBy[Q] {
order := o()
order.CollationName = collation
order.Collation = collation{name: collationName}

return OrderBy[Q](func() clause.OrderDef {
return order
Expand Down
8 changes: 8 additions & 0 deletions dialect/psql/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ WINDOW w AS (PARTITION BY depname ORDER BY salary)`,
sm.Window("w").PartitionBy("depname").OrderBy("salary"),
),
},
"select with order by and collate": {
Query: psql.Select(
sm.Columns("id", "name"),
sm.From("users"),
sm.OrderBy("name").Collate("bg-BG-x-icu").Asc(),
),
ExpectedSQL: `SELECT id, name FROM users ORDER BY name COLLATE "bg-BG-x-icu" ASC`,
},
}

testutils.RunTests(t, examples, formatter)
Expand Down
16 changes: 14 additions & 2 deletions dialect/sqlite/dialect/mods.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dialect

import (
"fmt"
"io"

"github.com/stephenafamo/bob"
Expand Down Expand Up @@ -228,15 +229,26 @@ func CrossJoin[Q Joinable](e any) bob.Mod[Q] {
return Join[Q](clause.CrossJoin, e)
}

type collation struct {
name string
}

func (c collation) WriteSQL(w io.Writer, d bob.Dialect, _ int) ([]any, error) {
if _, err := fmt.Fprintf(w, " COLLATE %s", c.name); err != nil {
return nil, err
}
return nil, nil
}

type OrderBy[Q interface{ AppendOrder(clause.OrderDef) }] func() clause.OrderDef

func (s OrderBy[Q]) Apply(q Q) {
q.AppendOrder(s())
}

func (o OrderBy[Q]) Collate(collation string) OrderBy[Q] {
func (o OrderBy[Q]) Collate(collationName string) OrderBy[Q] {
order := o()
order.CollationName = collation
order.Collation = &collation{name: collationName}

return OrderBy[Q](func() clause.OrderDef {
return order
Expand Down
8 changes: 8 additions & 0 deletions dialect/sqlite/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ func TestSelect(t *testing.T) {
ExpectedSQL: `SELECT id, name FROM users WHERE (("id", "employee_id") IN ((?1, ?2), (?3, ?4)))`,
ExpectedArgs: []any{100, 200, 300, 400},
},
"select with order by and collate": {
Query: sqlite.Select(
sm.Columns("id", "name"),
sm.From("users"),
sm.OrderBy("name").Collate("NOCASE").Asc(),
),
ExpectedSQL: `SELECT id, name FROM users ORDER BY name COLLATE NOCASE ASC`,
},
}

testutils.RunTests(t, examples, formatter)
Expand Down
4 changes: 2 additions & 2 deletions gen/output_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,12 @@ func TestGetOutputFilename(t *testing.T) {
t.Run(name, func(t *testing.T) {
notTest := getOutputFilename(tc.SchemaName, tc.TableName, false, tc.IsGo)
if diff := cmp.Diff(tc.Expected, notTest); diff != "" {
t.Fatalf(diff)
t.Fatal(diff)
}

isTest := getOutputFilename(tc.SchemaName, tc.TableName, true, tc.IsGo)
if diff := cmp.Diff(tc.Expected+"_test", isTest); diff != "" {
t.Fatalf(diff)
t.Fatal(diff)
}
})
}
Expand Down
4 changes: 2 additions & 2 deletions website/examples_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ func toMarkdown(destination string, cases []testcase) {
}

if c.doc == "" {
c.doc = strings.Title(c.name) //nolint:staticcheck
c.doc = strings.Title(c.name)
}
// write the sql query
fmt.Fprintf(buf, "## %s\n\nSQL:\n\n```sql\n%s\n```\n\n", c.doc, c.query)
Expand Down Expand Up @@ -380,7 +380,7 @@ func toMarkdown(destination string, cases []testcase) {
func markdownTitle(s string) string {
base := filepath.Base(s)
heading := strings.TrimSuffix(base, filepath.Ext(base))
return fmt.Sprintf("# %s\n\n", strings.Title(heading)) //nolint:staticcheck
return fmt.Sprintf("# %s\n\n", strings.Title(heading))
}

func reindent(s string) string {
Expand Down

0 comments on commit 22be31a

Please sign in to comment.