Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context to WriteSQL and WriteQuery #283

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added error constants for matching against both specific and generic unique constraint errors raised by the underlying database driver. (thanks @mbezhanov)
- Added support for regular expressions in the `only` and `except` table filters. (thanks @mbezhanov)

### Changed

- `context.Context` is now passed to `Query.WriteQuery()` and `Expression.WriteSQL()` methods. This allows for more control over how the query is built and executed.
This change made is possible to delete some hacks and simplify the codebase.
- The `Name()` and `NameAs()` methods of Views/Tables no longer need the context argument since the context will be passed when writing the expression. The API then becomes cleaner.
- Preloading mods no longer need to store a context internally. `SetLoadContext()` and `GetLoadContext()` have removed.
- The `ToExpr` field in `orm.RelSide` which was used for preloading is no longer needed and has been removed.

### Removed

- Remove MS SQL artifacts. (thanks @mbezhanov)
Expand Down
21 changes: 12 additions & 9 deletions build.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package bob

import "bytes"
import (
"bytes"
"context"
)

// MustBuild builds a query and panics on error
// useful for initializing queries that need to be reused
func MustBuild(q Query) (string, []any) {
return MustBuildN(q, 1)
func MustBuild(ctx context.Context, q Query) (string, []any) {
return MustBuildN(ctx, q, 1)
}

func MustBuildN(q Query, start int) (string, []any) {
sql, args, err := BuildN(q, start)
func MustBuildN(ctx context.Context, q Query, start int) (string, []any) {
sql, args, err := BuildN(ctx, q, start)
if err != nil {
panic(err)
}
Expand All @@ -18,14 +21,14 @@ func MustBuildN(q Query, start int) (string, []any) {
}

// Convinient function to build query from start
func Build(q Query) (string, []any, error) {
return BuildN(q, 1)
func Build(ctx context.Context, q Query) (string, []any, error) {
return BuildN(ctx, q, 1)
}

// Convinient function to build query from a point
func BuildN(q Query, start int) (string, []any, error) {
func BuildN(ctx context.Context, q Query, start int) (string, []any, error) {
b := &bytes.Buffer{}
args, err := q.WriteQuery(b, start)
args, err := q.WriteQuery(ctx, b, start)

return b.String(), args, err
}
11 changes: 6 additions & 5 deletions cached.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package bob

import (
"context"
"fmt"
"io"
)

func Cache(q Query) (BaseQuery[*cached], error) {
return CacheN(q, 1)
func Cache(ctx context.Context, q Query) (BaseQuery[*cached], error) {
return CacheN(ctx, q, 1)
}

func CacheN(q Query, start int) (BaseQuery[*cached], error) {
query, args, err := BuildN(q, start)
func CacheN(ctx context.Context, q Query, start int) (BaseQuery[*cached], error) {
query, args, err := BuildN(ctx, q, start)
if err != nil {
return BaseQuery[*cached]{}, err
}
Expand Down Expand Up @@ -40,7 +41,7 @@ type cached struct {
}

// WriteSQL implements Expression.
func (c *cached) WriteSQL(w io.Writer, d Dialect, start int) ([]any, error) {
func (c *cached) WriteSQL(ctx context.Context, w io.Writer, d Dialect, start int) ([]any, error) {
if start != c.start {
return nil, WrongStartError{Expected: c.start, Got: start}
}
Expand Down
5 changes: 3 additions & 2 deletions clause/combine.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clause

import (
"context"
"errors"
"io"

Expand All @@ -25,7 +26,7 @@ func (s *Combine) SetCombine(c Combine) {
*s = c
}

func (s Combine) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (s Combine) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
if s.Strategy == "" {
return nil, ErrNoCombinationStrategy
}
Expand All @@ -38,7 +39,7 @@ func (s Combine) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error)
w.Write([]byte(" "))
}

args, err := bob.Express(w, d, start, s.Query)
args, err := bob.Express(ctx, w, d, start, s.Query)
if err != nil {
return nil, err
}
Expand Down
17 changes: 9 additions & 8 deletions clause/conflict.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clause

import (
"context"
"io"

"github.com/stephenafamo/bob"
Expand All @@ -17,24 +18,24 @@ func (c *Conflict) SetConflict(conflict Conflict) {
*c = conflict
}

func (c Conflict) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (c Conflict) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
w.Write([]byte("ON CONFLICT"))

args, err := bob.ExpressIf(w, d, start, c.Target, true, "", "")
args, err := bob.ExpressIf(ctx, w, d, start, c.Target, true, "", "")
if err != nil {
return nil, err
}

w.Write([]byte(" DO "))
w.Write([]byte(c.Do))

setArgs, err := bob.ExpressIf(w, d, start+len(args), c.Set, len(c.Set.Set) > 0, " SET\n", "")
setArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.Set, len(c.Set.Set) > 0, " SET\n", "")
if err != nil {
return nil, err
}
args = append(args, setArgs...)

whereArgs, err := bob.ExpressIf(w, d, start+len(args), c.Where,
whereArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.Where,
len(c.Where.Conditions) > 0, "\n", "")
if err != nil {
return nil, err
Expand All @@ -50,17 +51,17 @@ type ConflictTarget struct {
Where []any
}

func (c ConflictTarget) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (c ConflictTarget) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
if c.Constraint != "" {
return bob.ExpressIf(w, d, start, c.Constraint, true, " ON CONSTRAINT ", "")
return bob.ExpressIf(ctx, w, d, start, c.Constraint, true, " ON CONSTRAINT ", "")
}

args, err := bob.ExpressSlice(w, d, start, c.Columns, " (", ", ", ")")
args, err := bob.ExpressSlice(ctx, w, d, start, c.Columns, " (", ", ", ")")
if err != nil {
return nil, err
}

whereArgs, err := bob.ExpressSlice(w, d, start+len(args), c.Where, " WHERE ", " AND ", "")
whereArgs, err := bob.ExpressSlice(ctx, w, d, start+len(args), c.Where, " WHERE ", " AND ", "")
if err != nil {
return nil, err
}
Expand Down
23 changes: 12 additions & 11 deletions clause/cte.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clause

import (
"context"
"fmt"
"io"

Expand All @@ -16,9 +17,9 @@ type CTE struct {
Cycle CTECycle
}

func (c CTE) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (c CTE) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
w.Write([]byte(c.Name))
_, err := bob.ExpressSlice(w, d, start, c.Columns, "(", ", ", ")")
_, err := bob.ExpressSlice(ctx, w, d, start, c.Columns, "(", ", ", ")")
if err != nil {
return nil, err
}
Expand All @@ -36,20 +37,20 @@ func (c CTE) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
}

w.Write([]byte("("))
args, err := c.Query.WriteQuery(w, start)
args, err := c.Query.WriteQuery(ctx, w, start)
if err != nil {
return nil, err
}
w.Write([]byte(")"))

searchArgs, err := bob.ExpressIf(w, d, start+len(args), c.Search,
searchArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.Search,
len(c.Search.Columns) > 0, "\n", "")
if err != nil {
return nil, err
}
args = append(args, searchArgs...)

cycleArgs, err := bob.ExpressIf(w, d, start+len(args), c.Cycle,
cycleArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.Cycle,
len(c.Cycle.Columns) > 0, "\n", "")
if err != nil {
return nil, err
Expand All @@ -70,11 +71,11 @@ type CTESearch struct {
Set string
}

func (c CTESearch) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (c CTESearch) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
// [ SEARCH { BREADTH | DEPTH } FIRST BY column_name [, ...] SET search_seq_col_name ]
fmt.Fprintf(w, "SEARCH %s FIRST BY ", c.Order)

args, err := bob.ExpressSlice(w, d, start, c.Columns, "", ", ", "")
args, err := bob.ExpressSlice(ctx, w, d, start, c.Columns, "", ", ", "")
if err != nil {
return nil, err
}
Expand All @@ -92,25 +93,25 @@ type CTECycle struct {
DefaultVal any
}

func (c CTECycle) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (c CTECycle) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
//[ CYCLE column_name [, ...] SET cycle_mark_col_name [ TO cycle_mark_value DEFAULT cycle_mark_default ] USING cycle_path_col_name ]
w.Write([]byte("CYCLE "))

args, err := bob.ExpressSlice(w, d, start, c.Columns, "", ", ", "")
args, err := bob.ExpressSlice(ctx, w, d, start, c.Columns, "", ", ", "")
if err != nil {
return nil, err
}

fmt.Fprintf(w, " SET %s", c.Set)

markArgs, err := bob.ExpressIf(w, d, start+len(args), c.SetVal,
markArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.SetVal,
c.SetVal != nil, " TO ", "")
if err != nil {
return nil, err
}
args = append(args, markArgs...)

defaultArgs, err := bob.ExpressIf(w, d, start+len(args), c.DefaultVal,
defaultArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), c.DefaultVal,
c.DefaultVal != nil, " DEFAULT ", "")
if err != nil {
return nil, err
Expand Down
3 changes: 2 additions & 1 deletion clause/fetch.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clause

import (
"context"
"io"
"strconv"

Expand All @@ -16,7 +17,7 @@ func (f *Fetch) SetFetch(fetch Fetch) {
*f = fetch
}

func (f Fetch) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (f Fetch) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
if f.Count == nil {
return nil, nil
}
Expand Down
5 changes: 3 additions & 2 deletions clause/for.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clause

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -32,7 +33,7 @@ func (f *For) SetFor(lock For) {
*f = lock
}

func (f For) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (f For) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
if f.Strength == "" {
return nil, nil
}
Expand All @@ -42,7 +43,7 @@ func (f For) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
fmt.Fprintf(w, "%s ", f.Strength)
}

args, err := bob.ExpressSlice(w, d, start, f.Tables, "OF ", ", ", "")
args, err := bob.ExpressSlice(ctx, w, d, start, f.Tables, "OF ", ", ", "")
if err != nil {
return nil, err
}
Expand Down
9 changes: 5 additions & 4 deletions clause/frame.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clause

import (
"context"
"io"

"github.com/stephenafamo/bob"
Expand Down Expand Up @@ -34,7 +35,7 @@ func (f *Frame) SetExclusion(excl string) {
f.Exclusion = excl
}

func (f Frame) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
func (f Frame) WriteSQL(ctx context.Context, w io.Writer, d bob.Dialect, start int) ([]any, error) {
if f.Mode == "" {
f.Mode = "RANGE"
}
Expand All @@ -52,19 +53,19 @@ func (f Frame) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
w.Write([]byte("BETWEEN "))
}

startArgs, err := bob.Express(w, d, start, f.Start)
startArgs, err := bob.Express(ctx, w, d, start, f.Start)
if err != nil {
return nil, err
}
args = append(args, startArgs...)

endArgs, err := bob.ExpressIf(w, d, start, f.End, f.End != nil, " AND ", "")
endArgs, err := bob.ExpressIf(ctx, w, d, start, f.End, f.End != nil, " AND ", "")
if err != nil {
return nil, err
}
args = append(args, endArgs...)

_, err = bob.ExpressIf(w, d, start, f.Exclusion, f.Exclusion != "", " EXCLUDE ", "")
_, err = bob.ExpressIf(ctx, w, d, start, f.Exclusion, f.Exclusion != "", " EXCLUDE ", "")
if err != nil {
return nil, err
}
Expand Down
Loading
Loading