Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions cmd/gosqlx/cmd/sql_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,14 @@ func (f *SQLFormatter) formatInsert(stmt *ast.InsertStatement) error {
if len(stmt.Values) > 0 {
f.writeNewline()
f.writeKeyword("VALUES")
f.builder.WriteString(" (")
f.formatExpressionList(stmt.Values, ", ")
f.builder.WriteString(")")
for i, row := range stmt.Values {
if i > 0 {
f.builder.WriteString(",")
}
f.builder.WriteString(" (")
f.formatExpressionList(row, ", ")
f.builder.WriteString(")")
}
}

if stmt.Query != nil {
Expand Down
6 changes: 4 additions & 2 deletions pkg/gosqlx/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -928,8 +928,10 @@ func (fc *functionCollector) collectFromNode(node ast.Node) {
fc.collectFromNode(n.With)
}
case *ast.InsertStatement:
for _, val := range n.Values {
fc.collectFromExpression(val)
for _, row := range n.Values {
for _, val := range row {
fc.collectFromExpression(val)
}
}
if n.Query != nil {
fc.collectFromNode(n.Query)
Expand Down
4 changes: 4 additions & 0 deletions pkg/models/token_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ const (
TokenTypeCube TokenType = 392
TokenTypeGrouping TokenType = 393
TokenTypeSets TokenType = 394 // SETS keyword for GROUPING SETS
TokenTypeArray TokenType = 395 // ARRAY keyword for PostgreSQL array constructor
TokenTypeWithin TokenType = 396 // WITHIN keyword for WITHIN GROUP clause

// Role/Permission Keywords (400-419)
TokenTypeRole TokenType = 400
Expand Down Expand Up @@ -620,6 +622,8 @@ var tokenStringMap = map[TokenType]string{
TokenTypeCube: "CUBE",
TokenTypeGrouping: "GROUPING",
TokenTypeSets: "SETS",
TokenTypeArray: "ARRAY",
TokenTypeWithin: "WITHIN",

// Role/Permission Keywords
TokenTypeRole: "ROLE",
Expand Down
57 changes: 49 additions & 8 deletions pkg/sql/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,13 +602,15 @@ func (i Identifier) Children() []Node { return nil }
// New in v1.6.0:
// - Filter: FILTER clause for conditional aggregation
// - OrderBy: ORDER BY clause for order-sensitive aggregates (STRING_AGG, ARRAY_AGG, etc.)
// - WithinGroup: ORDER BY clause for ordered-set aggregates (PERCENTILE_CONT, PERCENTILE_DISC, MODE, etc.)
type FunctionCall struct {
Name string
Arguments []Expression // Renamed from Args for consistency
Over *WindowSpec // For window functions
Distinct bool
Filter Expression // WHERE clause for aggregate functions
OrderBy []OrderByExpression // ORDER BY clause for aggregate functions (STRING_AGG, ARRAY_AGG, etc.)
Name string
Arguments []Expression // Renamed from Args for consistency
Over *WindowSpec // For window functions
Distinct bool
Filter Expression // WHERE clause for aggregate functions
OrderBy []OrderByExpression // ORDER BY clause for aggregate functions (STRING_AGG, ARRAY_AGG, etc.)
WithinGroup []OrderByExpression // ORDER BY clause for ordered-set aggregates (PERCENTILE_CONT, etc.)
}

func (f *FunctionCall) expressionNode() {}
Expand All @@ -625,6 +627,10 @@ func (f FunctionCall) Children() []Node {
orderBy := orderBy // G601: Create local copy to avoid memory aliasing
children = append(children, &orderBy)
}
for _, orderBy := range f.WithinGroup {
orderBy := orderBy // G601: Create local copy to avoid memory aliasing
children = append(children, &orderBy)
}
return children
}

Expand Down Expand Up @@ -887,6 +893,38 @@ func (l *ListExpression) expressionNode() {}
func (l ListExpression) TokenLiteral() string { return "LIST" }
func (l ListExpression) Children() []Node { return nodifyExpressions(l.Values) }

// TupleExpression represents a row constructor / tuple (col1, col2) for multi-column comparisons
// Used in: WHERE (user_id, status) IN ((1, 'active'), (2, 'pending'))
type TupleExpression struct {
Expressions []Expression
}

func (t *TupleExpression) expressionNode() {}
func (t TupleExpression) TokenLiteral() string { return "TUPLE" }
func (t TupleExpression) Children() []Node { return nodifyExpressions(t.Expressions) }

// ArrayConstructorExpression represents PostgreSQL ARRAY constructor syntax.
// Creates an array value from a list of expressions or a subquery.
//
// Examples:
//
// ARRAY[1, 2, 3] - Integer array literal
// ARRAY['admin', 'moderator'] - Text array literal
// ARRAY(SELECT id FROM users) - Array from subquery
type ArrayConstructorExpression struct {
Elements []Expression // Elements inside ARRAY[...]
Subquery *SelectStatement // For ARRAY(SELECT ...) syntax (optional)
}

func (a *ArrayConstructorExpression) expressionNode() {}
func (a ArrayConstructorExpression) TokenLiteral() string { return "ARRAY" }
func (a ArrayConstructorExpression) Children() []Node {
if a.Subquery != nil {
return []Node{a.Subquery}
}
return nodifyExpressions(a.Elements)
}

// UnaryExpression represents operations like NOT expr
type UnaryExpression struct {
Operator UnaryOperator
Expand Down Expand Up @@ -963,7 +1001,7 @@ type InsertStatement struct {
With *WithClause
TableName string
Columns []Expression
Values []Expression
Values [][]Expression // Multi-row support: each inner slice is one row of values
Query *SelectStatement // For INSERT ... SELECT
Returning []Expression
OnConflict *OnConflict
Expand All @@ -978,7 +1016,10 @@ func (i InsertStatement) Children() []Node {
children = append(children, i.With)
}
children = append(children, nodifyExpressions(i.Columns)...)
children = append(children, nodifyExpressions(i.Values)...)
// Flatten multi-row values for Children()
for _, row := range i.Values {
children = append(children, nodifyExpressions(row)...)
}
if i.Query != nil {
children = append(children, i.Query)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/ast/coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,10 @@ func TestSpanMethods(t *testing.T) {
&Identifier{Name: "id"},
&Identifier{Name: "name"},
},
Values: []Expression{
Values: [][]Expression{{
&LiteralValue{Value: 1},
&LiteralValue{Value: "test"},
},
}},
}
span := insert.Span()
// Should return combined span of components
Expand Down Expand Up @@ -830,7 +830,7 @@ func TestInsertStatementChildrenCoverage(t *testing.T) {
},
TableName: "users",
Columns: []Expression{&Identifier{Name: "id"}},
Values: []Expression{&LiteralValue{Value: 1}},
Values: [][]Expression{{&LiteralValue{Value: 1}}},
Query: &SelectStatement{},
Returning: []Expression{&Identifier{Name: "id"}},
OnConflict: &OnConflict{
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/ast/interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ func TestInsertStatementChildren(t *testing.T) {
stmt := &InsertStatement{
With: &WithClause{},
Columns: []Expression{testIdent},
Values: []Expression{testExpr},
Values: [][]Expression{{testExpr}},
Query: &SelectStatement{},
Returning: []Expression{testIdent},
OnConflict: &OnConflict{},
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/ast/nodes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ func TestInsertStatement(t *testing.T) {
stmt: &InsertStatement{
TableName: "users",
Columns: []Expression{&Identifier{Name: "name"}, &Identifier{Name: "email"}},
Values: []Expression{&LiteralValue{Value: "John"}, &LiteralValue{Value: "[email protected]"}},
Values: [][]Expression{{&LiteralValue{Value: "John"}, &LiteralValue{Value: "[email protected]"}}},
},
wantLiteral: "INSERT",
minChildren: 2,
Expand Down
89 changes: 86 additions & 3 deletions pkg/sql/ast/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ var (
New: func() interface{} {
return &InsertStatement{
Columns: make([]Expression, 0, 4),
Values: make([]Expression, 0, 4),
Values: make([][]Expression, 0, 4),
}
},
}
Expand Down Expand Up @@ -129,6 +129,22 @@ var (
},
}

tupleExprPool = sync.Pool{
New: func() interface{} {
return &TupleExpression{
Expressions: make([]Expression, 0, 4),
}
},
}

arrayConstructorPool = sync.Pool{
New: func() interface{} {
return &ArrayConstructorExpression{
Elements: make([]Expression, 0, 4),
}
},
}

subqueryExprPool = sync.Pool{
New: func() interface{} {
return &SubqueryExpression{}
Expand Down Expand Up @@ -337,9 +353,13 @@ func PutInsertStatement(stmt *InsertStatement) {
PutExpression(stmt.Columns[i])
stmt.Columns[i] = nil
}
// Clean up multi-row values
for i := range stmt.Values {
PutExpression(stmt.Values[i])
stmt.Values[i] = nil
for j := range stmt.Values[i] {
PutExpression(stmt.Values[i][j])
stmt.Values[i][j] = nil
}
stmt.Values[i] = stmt.Values[i][:0]
}

// Reset slices but keep capacity
Expand Down Expand Up @@ -784,6 +804,27 @@ func PutExpression(expr Expression) {
e.Values = e.Values[:0]
listExprPool.Put(e)

case *TupleExpression:
for i := range e.Expressions {
if e.Expressions[i] != nil {
workQueue = append(workQueue, e.Expressions[i])
}
e.Expressions[i] = nil
}
e.Expressions = e.Expressions[:0]
tupleExprPool.Put(e)

case *ArrayConstructorExpression:
for i := range e.Elements {
if e.Elements[i] != nil {
workQueue = append(workQueue, e.Elements[i])
}
e.Elements[i] = nil
}
e.Elements = e.Elements[:0]
e.Subquery = nil
arrayConstructorPool.Put(e)

case *UnaryExpression:
if e.Expr != nil {
workQueue = append(workQueue, e.Expr)
Expand Down Expand Up @@ -933,6 +974,48 @@ func PutInExpression(ie *InExpression) {
inExprPool.Put(ie)
}

// GetTupleExpression gets a TupleExpression from the pool
func GetTupleExpression() *TupleExpression {
te := tupleExprPool.Get().(*TupleExpression)
te.Expressions = te.Expressions[:0]
return te
}

// PutTupleExpression returns a TupleExpression to the pool
func PutTupleExpression(te *TupleExpression) {
if te == nil {
return
}
for i := range te.Expressions {
PutExpression(te.Expressions[i])
te.Expressions[i] = nil
}
te.Expressions = te.Expressions[:0]
tupleExprPool.Put(te)
}

// GetArrayConstructor gets an ArrayConstructorExpression from the pool
func GetArrayConstructor() *ArrayConstructorExpression {
ac := arrayConstructorPool.Get().(*ArrayConstructorExpression)
ac.Elements = ac.Elements[:0]
ac.Subquery = nil
return ac
}

// PutArrayConstructor returns an ArrayConstructorExpression to the pool
func PutArrayConstructor(ac *ArrayConstructorExpression) {
if ac == nil {
return
}
for i := range ac.Elements {
PutExpression(ac.Elements[i])
ac.Elements[i] = nil
}
ac.Elements = ac.Elements[:0]
ac.Subquery = nil
arrayConstructorPool.Put(ac)
}

// GetSubqueryExpression gets a SubqueryExpression from the pool
func GetSubqueryExpression() *SubqueryExpression {
return subqueryExprPool.Get().(*SubqueryExpression)
Expand Down
12 changes: 8 additions & 4 deletions pkg/sql/ast/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ func TestInsertStatementPool(t *testing.T) {
&Identifier{Name: "name"},
&Identifier{Name: "email"},
}
stmt.Values = []Expression{
&LiteralValue{Value: "John"},
&LiteralValue{Value: "[email protected]"},
// Values is now [][]Expression for multi-row support
stmt.Values = [][]Expression{
{
&LiteralValue{Value: "John"},
&LiteralValue{Value: "[email protected]"},
},
}

// Return to pool
Expand Down Expand Up @@ -371,7 +374,8 @@ func TestMemoryLeaks_InsertStatementPool(t *testing.T) {

stmt.TableName = "users"
stmt.Columns = append(stmt.Columns, &Identifier{Name: "name"}, &Identifier{Name: "email"})
stmt.Values = append(stmt.Values, &LiteralValue{Value: "John"}, &LiteralValue{Value: "[email protected]"})
// Values is now [][]Expression for multi-row support
stmt.Values = append(stmt.Values, []Expression{&LiteralValue{Value: "John"}, &LiteralValue{Value: "[email protected]"}})

PutInsertStatement(stmt)

Expand Down
8 changes: 5 additions & 3 deletions pkg/sql/ast/span.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ func (i *InsertStatement) Span() models.Span {
}
}

for _, val := range i.Values {
if spanned, ok := val.(Spanned); ok {
spans = append(spans, spanned.Span())
for _, row := range i.Values {
for _, val := range row {
if spanned, ok := val.(Spanned); ok {
spans = append(spans, spanned.Span())
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/ast/span_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func TestAST_Span(t *testing.T) {

stmt2 := &InsertStatement{
Columns: []Expression{},
Values: []Expression{},
Values: [][]Expression{},
}
SetSpan(stmt2, models.Span{
Start: models.Location{Line: 3, Column: 1},
Expand Down Expand Up @@ -239,7 +239,7 @@ func TestInsertStatement_Span(t *testing.T) {

stmt := &InsertStatement{
Columns: []Expression{col},
Values: []Expression{val},
Values: [][]Expression{{val}},
}

// Just call Span() to ensure it works
Expand All @@ -255,7 +255,7 @@ func TestInsertStatement_Span(t *testing.T) {

stmt := &InsertStatement{
Columns: []Expression{},
Values: []Expression{},
Values: [][]Expression{},
Returning: []Expression{ret},
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/keywords/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ var ADDITIONAL_KEYWORDS = []Keyword{
{Word: "SETS", Type: models.TokenTypeKeyword, Reserved: true, ReservedForTableAlias: false},
// FILTER clause for aggregate functions (SQL:2003 T612)
{Word: "FILTER", Type: models.TokenTypeFilter, Reserved: true, ReservedForTableAlias: false},
// ARRAY constructor (SQL-99, PostgreSQL)
{Word: "ARRAY", Type: models.TokenTypeArray, Reserved: true, ReservedForTableAlias: false},
// WITHIN GROUP ordered set aggregates (SQL:2003)
{Word: "WITHIN", Type: models.TokenTypeWithin, Reserved: true, ReservedForTableAlias: false},
// MERGE statement keywords (SQL:2003 F312)
{Word: "MERGE", Type: models.TokenTypeKeyword, Reserved: true, ReservedForTableAlias: true},
{Word: "USING", Type: models.TokenTypeKeyword, Reserved: true, ReservedForTableAlias: true},
Expand Down
Loading
Loading