diff --git a/pkg/gosqlx/between_complex_test.go b/pkg/gosqlx/between_complex_test.go new file mode 100644 index 0000000..839e0ef --- /dev/null +++ b/pkg/gosqlx/between_complex_test.go @@ -0,0 +1,334 @@ +// Package gosqlx - between_complex_test.go +// End-to-end tests for BETWEEN with complex expressions using high-level API (Issue #180) + +package gosqlx + +import ( + "strings" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" +) + +// TestParse_BetweenWithArithmeticExpressions tests BETWEEN with arithmetic via high-level API +func TestParse_BetweenWithArithmeticExpressions(t *testing.T) { + sql := "SELECT * FROM products WHERE price BETWEEN price * 0.9 AND price * 1.1" + + astObj, err := Parse(sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(astObj.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(astObj.Statements)) + } + + stmt, ok := astObj.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", astObj.Statements[0]) + } + + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected BetweenExpression, got %T", stmt.Where) + } + + // Verify lower bound is multiplication + lowerBinary, ok := betweenExpr.Lower.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected lower bound to be BinaryExpression, got %T", betweenExpr.Lower) + } + if lowerBinary.Operator != "*" { + t.Errorf("expected operator '*', got '%s'", lowerBinary.Operator) + } + + // Verify upper bound is multiplication + upperBinary, ok := betweenExpr.Upper.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected upper bound to be BinaryExpression, got %T", betweenExpr.Upper) + } + if upperBinary.Operator != "*" { + t.Errorf("expected operator '*', got '%s'", upperBinary.Operator) + } +} + +// TestParse_BetweenWithIntervalArithmetic tests BETWEEN with INTERVAL expressions via high-level API +func TestParse_BetweenWithIntervalArithmetic(t *testing.T) { + sql := "SELECT * FROM orders 
WHERE created_at BETWEEN NOW() - INTERVAL '30 days' AND NOW()" + + astObj, err := Parse(sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + stmt := astObj.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected BetweenExpression, got %T", stmt.Where) + } + + // Verify lower bound is subtraction (NOW() - INTERVAL) + lowerBinary, ok := betweenExpr.Lower.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected lower bound to be BinaryExpression, got %T", betweenExpr.Lower) + } + if lowerBinary.Operator != "-" { + t.Errorf("expected operator '-', got '%s'", lowerBinary.Operator) + } + + // Verify INTERVAL expression in lower bound + intervalExpr, ok := lowerBinary.Right.(*ast.IntervalExpression) + if !ok { + t.Fatalf("expected IntervalExpression, got %T", lowerBinary.Right) + } + if intervalExpr.Value != "30 days" { + t.Errorf("expected interval '30 days', got '%s'", intervalExpr.Value) + } +} + +// TestParse_BetweenWithSubqueries tests BETWEEN with subqueries via high-level API +func TestParse_BetweenWithSubqueries(t *testing.T) { + sql := "SELECT * FROM data WHERE value BETWEEN (SELECT min_val FROM limits) AND (SELECT max_val FROM limits)" + + astObj, err := Parse(sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + stmt := astObj.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected BetweenExpression, got %T", stmt.Where) + } + + // Verify both bounds are subqueries + _, ok = betweenExpr.Lower.(*ast.SubqueryExpression) + if !ok { + t.Fatalf("expected lower bound to be SubqueryExpression, got %T", betweenExpr.Lower) + } + + _, ok = betweenExpr.Upper.(*ast.SubqueryExpression) + if !ok { + t.Fatalf("expected upper bound to be SubqueryExpression, got %T", betweenExpr.Upper) + } +} + +// TestParse_BetweenWithFunctionCalls tests BETWEEN with function calls via high-level API +func 
TestParse_BetweenWithFunctionCalls(t *testing.T) { + sql := "SELECT * FROM orders WHERE amount BETWEEN MIN(price) AND MAX(price) * 2" + + astObj, err := Parse(sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + stmt := astObj.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected BetweenExpression, got %T", stmt.Where) + } + + // Verify lower bound is MIN function + lowerFunc, ok := betweenExpr.Lower.(*ast.FunctionCall) + if !ok { + t.Fatalf("expected lower bound to be FunctionCall, got %T", betweenExpr.Lower) + } + if lowerFunc.Name != "MIN" { + t.Errorf("expected function 'MIN', got '%s'", lowerFunc.Name) + } + + // Verify upper bound is arithmetic with MAX function + upperBinary, ok := betweenExpr.Upper.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected upper bound to be BinaryExpression, got %T", betweenExpr.Upper) + } + + upperFunc, ok := upperBinary.Left.(*ast.FunctionCall) + if !ok { + t.Fatalf("expected upper bound left to be FunctionCall, got %T", upperBinary.Left) + } + if upperFunc.Name != "MAX" { + t.Errorf("expected function 'MAX', got '%s'", upperFunc.Name) + } +} + +// TestParse_BetweenComplexScenarios tests various complex BETWEEN scenarios +func TestParse_BetweenComplexScenarios(t *testing.T) { + tests := []struct { + name string + sql string + expectError bool + }{ + { + name: "Arithmetic with addition", + sql: "SELECT * FROM t WHERE x BETWEEN a + b AND c + d", + expectError: false, + }, + { + name: "Arithmetic with subtraction", + sql: "SELECT * FROM t WHERE x BETWEEN a - b AND c - d", + expectError: false, + }, + { + name: "Mixed arithmetic", + sql: "SELECT * FROM t WHERE x BETWEEN a * b + c AND d / e - f", + expectError: false, + }, + { + name: "Nested function calls", + sql: "SELECT * FROM t WHERE x BETWEEN ROUND(AVG(y)) AND CEIL(MAX(z))", + expectError: false, + }, + { + name: "CAST expressions", + sql: "SELECT * FROM t WHERE x BETWEEN CAST(a AS 
INT) AND CAST(b AS INT)", + expectError: false, + }, + { + name: "String concatenation", + sql: "SELECT * FROM t WHERE x BETWEEN a || 'low' AND b || 'high'", + expectError: false, + }, + { + name: "Parenthesized expressions", + sql: "SELECT * FROM t WHERE x BETWEEN (a * 0.8) + discount AND (b * 1.2) - fee", + expectError: false, + }, + { + name: "INTERVAL arithmetic multiple", + sql: "SELECT * FROM t WHERE ts BETWEEN NOW() - INTERVAL '7 days' AND NOW() - INTERVAL '1 day'", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + astObj, err := Parse(tt.sql) + if tt.expectError { + if err == nil { + t.Error("expected error, got none") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(astObj.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(astObj.Statements)) + } + + stmt, ok := astObj.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", astObj.Statements[0]) + } + + if stmt.Where == nil { + t.Fatal("expected WHERE clause, got nil") + } + + _, ok = stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected BetweenExpression in WHERE, got %T", stmt.Where) + } + }) + } +} + +// TestParse_BetweenWithNotOperator tests NOT BETWEEN with complex expressions +func TestParse_BetweenWithNotOperator(t *testing.T) { + sql := "SELECT * FROM products WHERE price NOT BETWEEN price * 0.5 AND price * 2" + + astObj, err := Parse(sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + stmt := astObj.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected BetweenExpression, got %T", stmt.Where) + } + + if !betweenExpr.Not { + t.Error("expected NOT BETWEEN, but Not flag is false") + } + + // Verify both bounds are arithmetic expressions + _, ok = betweenExpr.Lower.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected lower bound to be 
BinaryExpression, got %T", betweenExpr.Lower) + } + + _, ok = betweenExpr.Upper.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected upper bound to be BinaryExpression, got %T", betweenExpr.Upper) + } +} + +// TestValidate_BetweenWithComplexExpressions tests Validate function with complex BETWEEN +func TestValidate_BetweenWithComplexExpressions(t *testing.T) { + sqls := []string{ + "SELECT * FROM products WHERE price BETWEEN price * 0.9 AND price * 1.1", + "SELECT * FROM orders WHERE created_at BETWEEN NOW() - INTERVAL '30 days' AND NOW()", + "SELECT * FROM data WHERE value BETWEEN (SELECT MIN(x) FROM limits) AND (SELECT MAX(x) FROM limits)", + } + + for _, sql := range sqls { + t.Run(sql, func(t *testing.T) { + err := Validate(sql) + if err != nil { + t.Errorf("unexpected validation error: %v", err) + } + }) + } +} + +// TestExtractMetadata_BetweenWithComplexExpressions tests ExtractMetadata function with complex BETWEEN +func TestExtractMetadata_BetweenWithComplexExpressions(t *testing.T) { + sql := "SELECT id, name, price FROM products WHERE price BETWEEN price * 0.9 AND price * 1.1 ORDER BY price" + + astObj, err := Parse(sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + metadata := ExtractMetadata(astObj) + if metadata == nil { + t.Fatal("expected metadata, got nil") + } + + // Verify tables + if len(metadata.Tables) != 1 { + t.Fatalf("expected 1 table, got %d", len(metadata.Tables)) + } + if metadata.Tables[0] != "products" { + t.Errorf("expected table 'products', got '%s'", metadata.Tables[0]) + } + + // Verify columns (should include id, name, price) + // Price appears multiple times: in SELECT list, WHERE clause, and ORDER BY + expectedColumns := []string{"id", "name", "price"} + for _, col := range expectedColumns { + found := false + for _, metadataCol := range metadata.Columns { + if strings.Contains(strings.ToLower(metadataCol), col) { + found = true + break + } + } + if !found { + t.Errorf("expected column '%s' not found in 
metadata.Columns: %v", col, metadata.Columns) + } + } + + // Verify the AST structure + stmt := astObj.Statements[0].(*ast.SelectStatement) + if stmt.Where == nil { + t.Error("expected WHERE clause, got nil") + } + if len(stmt.OrderBy) == 0 { + t.Error("expected ORDER BY clause") + } +} diff --git a/pkg/gosqlx/regex_operators_test.go b/pkg/gosqlx/regex_operators_test.go new file mode 100644 index 0000000..ae38980 --- /dev/null +++ b/pkg/gosqlx/regex_operators_test.go @@ -0,0 +1,224 @@ +package gosqlx + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" +) + +// TestRegexOperators_EndToEnd tests PostgreSQL regex operators using the full tokenizer->parser pipeline +// This ensures the entire flow works: tokenizer -> token converter -> parser -> AST +// Issue #190: Support PostgreSQL regular expression operators (~, ~*, !~, !~*) +func TestRegexOperators_EndToEnd(t *testing.T) { + tests := []struct { + name string + sql string + operator string + }{ + { + name: "Case-sensitive regex match (~)", + sql: "SELECT * FROM users WHERE name ~ '^J.*'", + operator: "~", + }, + { + name: "Case-insensitive regex match (~*)", + sql: "SELECT * FROM products WHERE description ~* 'sale|discount'", + operator: "~*", + }, + { + name: "Case-sensitive regex non-match (!~)", + sql: "SELECT * FROM logs WHERE message !~ 'DEBUG'", + operator: "!~", + }, + { + name: "Case-insensitive regex non-match (!~*)", + sql: "SELECT * FROM emails WHERE subject !~* 'spam'", + operator: "!~*", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Use Parse which goes through full tokenizer -> parser pipeline + astObj, err := Parse(tt.sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(astObj) + + if len(astObj.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(astObj.Statements)) + } + + stmt, ok := astObj.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", 
astObj.Statements[0]) + } + + if stmt.Where == nil { + t.Fatal("expected WHERE clause") + } + + binExpr, ok := stmt.Where.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected BinaryExpression, got %T", stmt.Where) + } + + if binExpr.Operator != tt.operator { + t.Errorf("expected operator %q, got %q", tt.operator, binExpr.Operator) + } + + // Verify left side is an identifier + leftIdent, ok := binExpr.Left.(*ast.Identifier) + if !ok { + t.Errorf("expected left side to be Identifier, got %T", binExpr.Left) + } else { + t.Logf("Left identifier: %s", leftIdent.Name) + } + + // Verify right side is a literal (the regex pattern) + rightLit, ok := binExpr.Right.(*ast.LiteralValue) + if !ok { + t.Errorf("expected right side to be LiteralValue, got %T", binExpr.Right) + } else { + t.Logf("Right literal: %v", rightLit.Value) + } + }) + } +} + +// TestRegexOperators_ComplexQueries tests regex operators in complex queries +func TestRegexOperators_ComplexQueries(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "Regex with AND condition", + sql: "SELECT * FROM users WHERE name ~ '^[A-Z]' AND email ~* '@example.com$'", + }, + { + name: "Regex with OR condition", + sql: "SELECT * FROM products WHERE name !~ 'deprecated' OR status = 'active'", + }, + { + name: "Multiple regex operators", + sql: "SELECT * FROM logs WHERE message ~ 'ERROR' AND message !~* 'ignored'", + }, + { + name: "Regex with parentheses", + sql: "SELECT * FROM users WHERE (name ~ '^Admin' OR email ~* '@admin.com') AND status = 'active'", + }, + { + name: "Regex in JOIN condition", + sql: "SELECT * FROM users u JOIN logs l ON l.user_id = u.id WHERE l.message ~ 'ERROR'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + astObj, err := Parse(tt.sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(astObj) + + if len(astObj.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(astObj.Statements)) + } + + 
stmt, ok := astObj.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", astObj.Statements[0]) + } + + // Just verify we can parse it successfully - structure validation is done in other tests + if stmt.Where == nil { + t.Fatal("expected WHERE clause") + } + + t.Logf("Successfully parsed: %s", tt.sql) + }) + } +} + +// TestRegexOperators_Subqueries tests regex operators in subqueries +func TestRegexOperators_Subqueries(t *testing.T) { + sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM logs WHERE message ~ 'ERROR')" + + astObj, err := Parse(sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(astObj) + + if len(astObj.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(astObj.Statements)) + } + + stmt, ok := astObj.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", astObj.Statements[0]) + } + + if stmt.Where == nil { + t.Fatal("expected WHERE clause") + } + + // The WHERE clause should be an IN expression + inExpr, ok := stmt.Where.(*ast.InExpression) + if !ok { + t.Fatalf("expected InExpression, got %T", stmt.Where) + } + + if inExpr.Subquery == nil { + t.Fatal("expected subquery in IN expression") + } + + t.Log("Successfully parsed regex operator in subquery") +} + +// TestRegexOperators_TypeCasting tests regex operators with type casting +func TestRegexOperators_TypeCasting(t *testing.T) { + sql := "SELECT * FROM users WHERE id::text ~ '^[0-9]+$'" + + astObj, err := Parse(sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(astObj) + + if len(astObj.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(astObj.Statements)) + } + + stmt, ok := astObj.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", astObj.Statements[0]) + } + + if stmt.Where == nil { + t.Fatal("expected WHERE clause") + } + + binExpr, ok := 
stmt.Where.(*ast.BinaryExpression)
+	if !ok {
+		t.Fatalf("expected BinaryExpression, got %T", stmt.Where)
+	}
+
+	if binExpr.Operator != "~" {
+		t.Errorf("expected operator '~', got %q", binExpr.Operator)
+	}
+
+	// Left side should be a cast expression
+	castExpr, ok := binExpr.Left.(*ast.CastExpression)
+	if !ok {
+		t.Fatalf("expected left side to be CastExpression, got %T", binExpr.Left)
+	}
+
+	if castExpr.Type != "text" {
+		t.Errorf("expected cast type 'text', got %q", castExpr.Type)
+	}
+
+	t.Log("Successfully parsed regex operator with type cast")
+}
diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go
index c1bfd80..d7fafa2 100644
--- a/pkg/sql/ast/ast.go
+++ b/pkg/sql/ast/ast.go
@@ -1038,6 +1038,63 @@ func (i *IntervalExpression) expressionNode() {}
 func (i IntervalExpression) TokenLiteral() string { return "INTERVAL" }
 func (i IntervalExpression) Children() []Node     { return []Node{} }
 
+// ArraySubscriptExpression represents array element access syntax.
+// Supports single and multi-dimensional array subscripting.
+//
+// Examples:
+//
+//	tags[1] - Single subscript
+//	matrix[2][3] - Multi-dimensional subscript
+//	arr[i] - Subscript with variable
+//	(SELECT arr)[1] - Subscript on subquery result
+//
+// The parser tests added with this type expect chained subscripts such as
+// matrix[2][3] to nest: the outer node's Array is the inner subscript node.
+type ArraySubscriptExpression struct {
+	Array   Expression   // The array expression being subscripted
+	Indices []Expression // Subscript indices (one or more for multi-dimensional arrays)
+}
+
+func (a *ArraySubscriptExpression) expressionNode()     {}
+func (a ArraySubscriptExpression) TokenLiteral() string { return "[]" }
+func (a ArraySubscriptExpression) Children() []Node {
+	// Size the slice once: the array operand plus every subscript index.
+	children := make([]Node, 0, 1+len(a.Indices))
+	children = append(children, a.Array)
+	for _, idx := range a.Indices {
+		children = append(children, idx)
+	}
+	return children
+}
+
+// ArraySliceExpression represents array slicing syntax for extracting subarrays.
+// Supports PostgreSQL-style array slicing with optional start/end bounds.
+//
+// Examples:
+//
+//	arr[1:3] - Slice from index 1 to 3 (inclusive)
+//	arr[2:] - Slice from index 2 to end
+//	arr[:5] - Slice from start to index 5
+//	arr[:] - Full array slice (copy)
+type ArraySliceExpression struct {
+	Array Expression // The array expression being sliced
+	Start Expression // Start index (nil means from beginning)
+	End   Expression // End index (nil means to end)
+}
+
+func (a *ArraySliceExpression) expressionNode()     {}
+func (a ArraySliceExpression) TokenLiteral() string { return "[:]" }
+func (a ArraySliceExpression) Children() []Node {
+	children := []Node{a.Array}
+	if a.Start != nil {
+		children = append(children, a.Start)
+	}
+	if a.End != nil {
+		children = append(children, a.End)
+	}
+	return children
+}
+
 // InsertStatement represents an INSERT SQL statement
 type InsertStatement struct {
 	With *WithClause
diff --git a/pkg/sql/ast/pool.go b/pkg/sql/ast/pool.go
index d3611f2..2f59744 100644
--- a/pkg/sql/ast/pool.go
+++ b/pkg/sql/ast/pool.go
@@ -163,6 +163,20 @@ var (
 		},
 	}
 
+	arraySubscriptExprPool = sync.Pool{
+		New: func() interface{} {
+			return &ArraySubscriptExpression{
+				Indices: make([]Expression, 0, 2), // Most common: 1-2 dimensions
+			}
+		},
+	}
+
+	arraySliceExprPool = sync.Pool{
+		New: func() interface{} {
+			return &ArraySliceExpression{}
+		},
+	}
+
 	// Additional expression pools for complete coverage
 	existsExprPool = sync.Pool{
 		New: func() interface{} {
@@ -784,6 +798,37 @@ func PutExpression(expr Expression) {
 			e.Value = ""
 			intervalExprPool.Put(e)
 
+		case *ArraySubscriptExpression:
+			if e.Array != nil {
+				workQueue = append(workQueue, e.Array)
+			}
+			for i := range e.Indices {
+				if e.Indices[i] != nil {
+					workQueue = append(workQueue, e.Indices[i])
+				}
+				// Drop the reference so the pooled backing array does not
+				// keep released nodes reachable after the reslice below.
+				e.Indices[i] = nil
+			}
+			e.Array = nil
+			e.Indices = e.Indices[:0]
+			arraySubscriptExprPool.Put(e)
+
+		case *ArraySliceExpression:
+			if e.Array != nil {
+				workQueue = append(workQueue, e.Array)
+			}
+			if e.Start != nil {
+				workQueue = append(workQueue, e.Start)
+			}
+			if e.End != nil {
+				workQueue = append(workQueue, e.End)
+			}
+			e.Array = nil
+			e.Start = nil
+			e.End = nil
+			arraySliceExprPool.Put(e)
+
 		case *ExistsExpression:
 			e.Subquery = nil
 			existsExprPool.Put(e)
@@ -1087,3 +1132,59 @@ func PutAliasedExpression(ae *AliasedExpression) {
 	ae.Alias = ""
 	aliasedExprPool.Put(ae)
 }
+
+// GetArraySubscriptExpression gets an ArraySubscriptExpression from the pool
+func GetArraySubscriptExpression() *ArraySubscriptExpression {
+	return arraySubscriptExprPool.Get().(*ArraySubscriptExpression)
+}
+
+// PutArraySubscriptExpression returns an ArraySubscriptExpression to the pool.
+// The array operand and each non-nil index are handed to PutExpression first,
+// and the receiver's fields are reset before it is pooled.
+func PutArraySubscriptExpression(ase *ArraySubscriptExpression) {
+	if ase == nil {
+		return
+	}
+	// Clean up array expression
+	if ase.Array != nil {
+		PutExpression(ase.Array)
+		ase.Array = nil
+	}
+	// Clean up indices, clearing each slot so the reused backing array
+	// does not keep already-released nodes reachable.
+	for i := range ase.Indices {
+		if ase.Indices[i] != nil {
+			PutExpression(ase.Indices[i])
+			ase.Indices[i] = nil
+		}
+	}
+	ase.Indices = ase.Indices[:0] // Clear slice but keep capacity
+	arraySubscriptExprPool.Put(ase)
+}
+
+// GetArraySliceExpression gets an ArraySliceExpression from the pool
+func GetArraySliceExpression() *ArraySliceExpression {
+	return arraySliceExprPool.Get().(*ArraySliceExpression)
+}
+
+// PutArraySliceExpression returns an ArraySliceExpression to the pool
+func PutArraySliceExpression(ase *ArraySliceExpression) {
+	if ase == nil {
+		return
+	}
+	// Clean up array expression
+	if ase.Array != nil {
+		PutExpression(ase.Array)
+		ase.Array = nil
+	}
+	// Clean up start/end expressions
+	if ase.Start != nil {
+		PutExpression(ase.Start)
+		ase.Start = nil
+	}
+	if ase.End != nil {
+		PutExpression(ase.End)
+		ase.End = nil
+	}
+	arraySliceExprPool.Put(ase)
+}
diff --git a/pkg/sql/parser/array_subscript_test.go b/pkg/sql/parser/array_subscript_test.go
new file mode 100644
index 0000000..9d9425f
--- /dev/null
+++ b/pkg/sql/parser/array_subscript_test.go
@@ -0,0 +1,534 @@
+// Package parser - array_subscript_test.go
+// Tests for array subscript and slice syntax (Issue #191)
+
+package parser + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/tokenizer" +) + +// TestParser_ArraySubscript_Single tests single array subscript +func TestParser_ArraySubscript_Single(t *testing.T) { + sql := "SELECT tags[1] FROM posts" + + // Tokenize the SQL + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("tokenizer error: %v", err) + } + + // Convert to parser tokens + parserTokens, err := ConvertTokensForParser(tokens) + if err != nil { + t.Fatalf("token conversion error: %v", err) + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(parserTokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + if len(tree.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + + stmt, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", tree.Statements[0]) + } + + if len(stmt.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(stmt.Columns)) + } + + // The column should be an ArraySubscriptExpression + subscriptExpr, ok := stmt.Columns[0].(*ast.ArraySubscriptExpression) + if !ok { + t.Fatalf("expected ArraySubscriptExpression, got %T", stmt.Columns[0]) + } + + // Check the array is an identifier "tags" + arrayIdent, ok := subscriptExpr.Array.(*ast.Identifier) + if !ok { + t.Fatalf("expected array to be Identifier, got %T", subscriptExpr.Array) + } + if arrayIdent.Name != "tags" { + t.Errorf("expected array name 'tags', got '%s'", arrayIdent.Name) + } + + // Check we have one index + if len(subscriptExpr.Indices) != 1 { + t.Fatalf("expected 1 index, got %d", len(subscriptExpr.Indices)) + } + + // Check the index is 1 + indexLiteral, ok := subscriptExpr.Indices[0].(*ast.LiteralValue) + if !ok { + t.Fatalf("expected index to be LiteralValue, got %T", 
subscriptExpr.Indices[0]) + } + if indexLiteral.Value != "1" { + t.Errorf("expected index value '1', got '%v'", indexLiteral.Value) + } +} + +// TestParser_ArraySubscript_MultiDimensional tests multi-dimensional array subscript +func TestParser_ArraySubscript_MultiDimensional(t *testing.T) { + sql := "SELECT matrix[2][3] FROM data" + + // Tokenize the SQL + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("tokenizer error: %v", err) + } + + // Convert to parser tokens + parserTokens, err := ConvertTokensForParser(tokens) + if err != nil { + t.Fatalf("token conversion error: %v", err) + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(parserTokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + if len(tree.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + + stmt, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", tree.Statements[0]) + } + + if len(stmt.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(stmt.Columns)) + } + + // The column should be an ArraySubscriptExpression (outer subscript [3]) + outerSubscript, ok := stmt.Columns[0].(*ast.ArraySubscriptExpression) + if !ok { + t.Fatalf("expected outer ArraySubscriptExpression, got %T", stmt.Columns[0]) + } + + // The array should be another ArraySubscriptExpression (inner subscript [2]) + innerSubscript, ok := outerSubscript.Array.(*ast.ArraySubscriptExpression) + if !ok { + t.Fatalf("expected inner ArraySubscriptExpression, got %T", outerSubscript.Array) + } + + // The innermost array should be an identifier "matrix" + arrayIdent, ok := innerSubscript.Array.(*ast.Identifier) + if !ok { + t.Fatalf("expected array to be Identifier, got %T", innerSubscript.Array) + } + if arrayIdent.Name != "matrix" { + t.Errorf("expected array name 
'matrix', got '%s'", arrayIdent.Name) + } +} + +// TestParser_ArraySlice_BothBounds tests array slice with both start and end +func TestParser_ArraySlice_BothBounds(t *testing.T) { + sql := "SELECT tags[1:3] FROM posts" + + // Tokenize the SQL + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("tokenizer error: %v", err) + } + + // Convert to parser tokens + parserTokens, err := ConvertTokensForParser(tokens) + if err != nil { + t.Fatalf("token conversion error: %v", err) + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(parserTokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + if len(tree.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + + stmt, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", tree.Statements[0]) + } + + if len(stmt.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(stmt.Columns)) + } + + // The column should be an ArraySliceExpression + sliceExpr, ok := stmt.Columns[0].(*ast.ArraySliceExpression) + if !ok { + t.Fatalf("expected ArraySliceExpression, got %T", stmt.Columns[0]) + } + + // Check the array is an identifier "tags" + arrayIdent, ok := sliceExpr.Array.(*ast.Identifier) + if !ok { + t.Fatalf("expected array to be Identifier, got %T", sliceExpr.Array) + } + if arrayIdent.Name != "tags" { + t.Errorf("expected array name 'tags', got '%s'", arrayIdent.Name) + } + + // Check start index is 1 + startLiteral, ok := sliceExpr.Start.(*ast.LiteralValue) + if !ok { + t.Fatalf("expected start to be LiteralValue, got %T", sliceExpr.Start) + } + if startLiteral.Value != "1" { + t.Errorf("expected start value '1', got '%v'", startLiteral.Value) + } + + // Check end index is 3 + endLiteral, ok := sliceExpr.End.(*ast.LiteralValue) + if !ok { + t.Fatalf("expected end to be 
LiteralValue, got %T", sliceExpr.End) + } + if endLiteral.Value != "3" { + t.Errorf("expected end value '3', got '%v'", endLiteral.Value) + } +} + +// TestParser_ArraySlice_StartOnly tests array slice with start only (arr[2:]) +func TestParser_ArraySlice_StartOnly(t *testing.T) { + sql := "SELECT arr[2:] FROM table_name" + + // Tokenize the SQL + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("tokenizer error: %v", err) + } + + // Convert to parser tokens + parserTokens, err := ConvertTokensForParser(tokens) + if err != nil { + t.Fatalf("token conversion error: %v", err) + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(parserTokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + if len(tree.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + + stmt, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", tree.Statements[0]) + } + + if len(stmt.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(stmt.Columns)) + } + + // The column should be an ArraySliceExpression + sliceExpr, ok := stmt.Columns[0].(*ast.ArraySliceExpression) + if !ok { + t.Fatalf("expected ArraySliceExpression, got %T", stmt.Columns[0]) + } + + // Check start index is 2 + startLiteral, ok := sliceExpr.Start.(*ast.LiteralValue) + if !ok { + t.Fatalf("expected start to be LiteralValue, got %T", sliceExpr.Start) + } + if startLiteral.Value != "2" { + t.Errorf("expected start value '2', got '%v'", startLiteral.Value) + } + + // Check end is nil + if sliceExpr.End != nil { + t.Errorf("expected end to be nil, got %T", sliceExpr.End) + } +} + +// TestParser_ArraySlice_EndOnly tests array slice with end only (arr[:5]) +func TestParser_ArraySlice_EndOnly(t *testing.T) { + sql := "SELECT arr[:5] FROM table_name" + + // Tokenize the SQL + 
tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("tokenizer error: %v", err) + } + + // Convert to parser tokens + parserTokens, err := ConvertTokensForParser(tokens) + if err != nil { + t.Fatalf("token conversion error: %v", err) + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(parserTokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + if len(tree.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + + stmt, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", tree.Statements[0]) + } + + if len(stmt.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(stmt.Columns)) + } + + // The column should be an ArraySliceExpression + sliceExpr, ok := stmt.Columns[0].(*ast.ArraySliceExpression) + if !ok { + t.Fatalf("expected ArraySliceExpression, got %T", stmt.Columns[0]) + } + + // Check start is nil + if sliceExpr.Start != nil { + t.Errorf("expected start to be nil, got %T", sliceExpr.Start) + } + + // Check end index is 5 + endLiteral, ok := sliceExpr.End.(*ast.LiteralValue) + if !ok { + t.Fatalf("expected end to be LiteralValue, got %T", sliceExpr.End) + } + if endLiteral.Value != "5" { + t.Errorf("expected end value '5', got '%v'", endLiteral.Value) + } +} + +// TestParser_ArraySubscript_InWhereClause tests array subscript in WHERE clause +func TestParser_ArraySubscript_InWhereClause(t *testing.T) { + sql := "SELECT * FROM posts WHERE tags[1] = 'tech'" + + // Tokenize the SQL + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("tokenizer error: %v", err) + } + + // Convert to parser tokens + parserTokens, err := ConvertTokensForParser(tokens) + if err != nil { + t.Fatalf("token conversion error: %v", err) + } 
+ + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(parserTokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + if len(tree.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + + stmt, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", tree.Statements[0]) + } + + if stmt.Where == nil { + t.Fatal("expected WHERE clause") + } + + // WHERE condition should be: tags[1] = 'tech' + binExpr, ok := stmt.Where.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected WHERE to be BinaryExpression, got %T", stmt.Where) + } + + // Left side should be ArraySubscriptExpression + subscriptExpr, ok := binExpr.Left.(*ast.ArraySubscriptExpression) + if !ok { + t.Fatalf("expected left side to be ArraySubscriptExpression, got %T", binExpr.Left) + } + + // Check the array is "tags" + arrayIdent, ok := subscriptExpr.Array.(*ast.Identifier) + if !ok { + t.Fatalf("expected array to be Identifier, got %T", subscriptExpr.Array) + } + if arrayIdent.Name != "tags" { + t.Errorf("expected array name 'tags', got '%s'", arrayIdent.Name) + } +} + +// TestParser_ArraySubscript_OnParenthesizedExpr tests array subscript on parenthesized expression +func TestParser_ArraySubscript_OnParenthesizedExpr(t *testing.T) { + sql := "SELECT (arr)[1] FROM table_name" + + // Tokenize the SQL + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("tokenizer error: %v", err) + } + + // Convert to parser tokens + parserTokens, err := ConvertTokensForParser(tokens) + if err != nil { + t.Fatalf("token conversion error: %v", err) + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(parserTokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + if len(tree.Statements) != 1 { + 
t.Fatalf("expected 1 statement, got %d", len(tree.Statements)) + } + + stmt, ok := tree.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", tree.Statements[0]) + } + + if len(stmt.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(stmt.Columns)) + } + + // The column should be an ArraySubscriptExpression + subscriptExpr, ok := stmt.Columns[0].(*ast.ArraySubscriptExpression) + if !ok { + t.Fatalf("expected ArraySubscriptExpression, got %T", stmt.Columns[0]) + } + + // The array should be an identifier + arrayIdent, ok := subscriptExpr.Array.(*ast.Identifier) + if !ok { + t.Fatalf("expected array to be Identifier, got %T", subscriptExpr.Array) + } + if arrayIdent.Name != "arr" { + t.Errorf("expected array name 'arr', got '%s'", arrayIdent.Name) + } +} + +// TestParser_ArraySubscript_ErrorEmptyBrackets tests error on empty brackets +func TestParser_ArraySubscript_ErrorEmptyBrackets(t *testing.T) { + sql := "SELECT arr[] FROM table_name" + + // Tokenize the SQL + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("tokenizer error: %v", err) + } + + // Convert to parser tokens + parserTokens, err := ConvertTokensForParser(tokens) + if err != nil { + t.Fatalf("token conversion error: %v", err) + } + + parser := NewParser() + defer parser.Release() + + _, err = parser.Parse(parserTokens) + if err == nil { + t.Fatal("expected error for empty array brackets, got nil") + } + + // The error should mention empty brackets + errMsg := err.Error() + if errMsg == "" { + t.Error("expected non-empty error message") + } +} + +// TestParser_ArraySubscriptExpression_Children tests that Children() returns correct nodes +func TestParser_ArraySubscriptExpression_Children(t *testing.T) { + arrayIdent := &ast.Identifier{Name: "arr"} + index1 := &ast.LiteralValue{Value: "1", Type: "int"} + index2 := &ast.LiteralValue{Value: "2", Type: "int"} + + 
subscriptExpr := &ast.ArraySubscriptExpression{ + Array: arrayIdent, + Indices: []ast.Expression{index1, index2}, + } + + children := subscriptExpr.Children() + if len(children) != 3 { + t.Errorf("expected 3 children, got %d", len(children)) + } +} + +// TestParser_ArraySliceExpression_Children tests that Children() returns correct nodes +func TestParser_ArraySliceExpression_Children(t *testing.T) { + arrayIdent := &ast.Identifier{Name: "arr"} + start := &ast.LiteralValue{Value: "1", Type: "int"} + end := &ast.LiteralValue{Value: "3", Type: "int"} + + sliceExpr := &ast.ArraySliceExpression{ + Array: arrayIdent, + Start: start, + End: end, + } + + children := sliceExpr.Children() + // Should have 3: array, start, end + if len(children) != 3 { + t.Errorf("expected 3 children, got %d", len(children)) + } +} diff --git a/pkg/sql/parser/between_complex_test.go b/pkg/sql/parser/between_complex_test.go new file mode 100644 index 0000000..ce3a8a2 --- /dev/null +++ b/pkg/sql/parser/between_complex_test.go @@ -0,0 +1,559 @@ +// Package parser - between_complex_test.go +// Comprehensive tests for BETWEEN with complex expressions (Issue #180) + +package parser + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/token" +) + +// TestParser_BetweenWithIntervalArithmetic tests BETWEEN with INTERVAL expressions +// Example: SELECT * FROM orders WHERE created_at BETWEEN NOW() - INTERVAL '30 days' AND NOW() +func TestParser_BetweenWithIntervalArithmetic(t *testing.T) { + tokens := []token.Token{ + {Type: "SELECT", Literal: "SELECT"}, + {Type: "*", Literal: "*"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "orders"}, + {Type: "WHERE", Literal: "WHERE"}, + {Type: "IDENT", Literal: "created_at"}, + {Type: "BETWEEN", Literal: "BETWEEN"}, + {Type: "IDENT", Literal: "NOW"}, + {Type: "(", Literal: "("}, + {Type: ")", Literal: ")"}, + {Type: "MINUS", Literal: "-"}, + {Type: "INTERVAL", Literal: "INTERVAL"}, + 
{Type: "STRING", Literal: "30 days"}, + {Type: "AND", Literal: "AND"}, + {Type: "IDENT", Literal: "NOW"}, + {Type: "(", Literal: "("}, + {Type: ")", Literal: ")"}, + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(tokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + stmt := tree.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected WHERE to be BetweenExpression, got %T", stmt.Where) + } + + // Verify main expression is 'created_at' + ident, ok := betweenExpr.Expr.(*ast.Identifier) + if !ok { + t.Fatalf("expected Expr to be Identifier, got %T", betweenExpr.Expr) + } + if ident.Name != "created_at" { + t.Errorf("expected Expr name 'created_at', got '%s'", ident.Name) + } + + // Verify lower bound is a binary expression (NOW() - INTERVAL '30 days') + lowerBinary, ok := betweenExpr.Lower.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected lower bound to be BinaryExpression, got %T", betweenExpr.Lower) + } + if lowerBinary.Operator != "-" { + t.Errorf("expected lower bound operator '-', got '%s'", lowerBinary.Operator) + } + + // Verify lower bound left side is NOW() function call + lowerFunc, ok := lowerBinary.Left.(*ast.FunctionCall) + if !ok { + t.Fatalf("expected lower bound left to be FunctionCall, got %T", lowerBinary.Left) + } + if lowerFunc.Name != "NOW" { + t.Errorf("expected function name 'NOW', got '%s'", lowerFunc.Name) + } + + // Verify lower bound right side is INTERVAL expression + intervalExpr, ok := lowerBinary.Right.(*ast.IntervalExpression) + if !ok { + t.Fatalf("expected lower bound right to be IntervalExpression, got %T", lowerBinary.Right) + } + if intervalExpr.Value != "30 days" { + t.Errorf("expected interval value '30 days', got '%s'", intervalExpr.Value) + } + + // Verify upper bound is NOW() function call + upperFunc, ok := betweenExpr.Upper.(*ast.FunctionCall) + if !ok { + 
t.Fatalf("expected upper bound to be FunctionCall, got %T", betweenExpr.Upper) + } + if upperFunc.Name != "NOW" { + t.Errorf("expected function name 'NOW', got '%s'", upperFunc.Name) + } +} + +// TestParser_BetweenWithSubqueries tests BETWEEN with subquery expressions +// Example: SELECT * FROM data WHERE value BETWEEN (SELECT min_val FROM limits) AND (SELECT max_val FROM limits) +func TestParser_BetweenWithSubqueries(t *testing.T) { + tokens := []token.Token{ + {Type: "SELECT", Literal: "SELECT"}, + {Type: "*", Literal: "*"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "data"}, + {Type: "WHERE", Literal: "WHERE"}, + {Type: "IDENT", Literal: "value"}, + {Type: "BETWEEN", Literal: "BETWEEN"}, + {Type: "(", Literal: "("}, + {Type: "SELECT", Literal: "SELECT"}, + {Type: "IDENT", Literal: "min_val"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "limits"}, + {Type: ")", Literal: ")"}, + {Type: "AND", Literal: "AND"}, + {Type: "(", Literal: "("}, + {Type: "SELECT", Literal: "SELECT"}, + {Type: "IDENT", Literal: "max_val"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "limits"}, + {Type: ")", Literal: ")"}, + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(tokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + stmt := tree.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected WHERE to be BetweenExpression, got %T", stmt.Where) + } + + // Verify main expression is 'value' + ident, ok := betweenExpr.Expr.(*ast.Identifier) + if !ok { + t.Fatalf("expected Expr to be Identifier, got %T", betweenExpr.Expr) + } + if ident.Name != "value" { + t.Errorf("expected Expr name 'value', got '%s'", ident.Name) + } + + // Verify lower bound is a subquery + lowerSubquery, ok := betweenExpr.Lower.(*ast.SubqueryExpression) + if !ok { + t.Fatalf("expected lower bound to be 
SubqueryExpression, got %T", betweenExpr.Lower) + } + + lowerSelect, ok := lowerSubquery.Subquery.(*ast.SelectStatement) + if !ok { + t.Fatalf("expected lower subquery to be SelectStatement, got %T", lowerSubquery.Subquery) + } + if len(lowerSelect.Columns) != 1 { + t.Errorf("expected 1 column in lower subquery, got %d", len(lowerSelect.Columns)) + } + + // Verify upper bound is a subquery + upperSubquery, ok := betweenExpr.Upper.(*ast.SubqueryExpression) + if !ok { + t.Fatalf("expected upper bound to be SubqueryExpression, got %T", betweenExpr.Upper) + } + + upperSelect, ok := upperSubquery.Subquery.(*ast.SelectStatement) + if !ok { + t.Fatalf("expected upper subquery to be SelectStatement, got %T", upperSubquery.Subquery) + } + if len(upperSelect.Columns) != 1 { + t.Errorf("expected 1 column in upper subquery, got %d", len(upperSelect.Columns)) + } +} + +// TestParser_BetweenWithMixedComplexExpressions tests BETWEEN with various complex expression types +// Example: SELECT * FROM sales WHERE amount BETWEEN (price * 0.8) + discount AND (price * 1.2) - fee +func TestParser_BetweenWithMixedComplexExpressions(t *testing.T) { + tokens := []token.Token{ + {Type: "SELECT", Literal: "SELECT"}, + {Type: "*", Literal: "*"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "sales"}, + {Type: "WHERE", Literal: "WHERE"}, + {Type: "IDENT", Literal: "amount"}, + {Type: "BETWEEN", Literal: "BETWEEN"}, + {Type: "(", Literal: "("}, + {Type: "IDENT", Literal: "price"}, + {Type: "*", Literal: "*"}, + {Type: "FLOAT", Literal: "0.8"}, + {Type: ")", Literal: ")"}, + {Type: "PLUS", Literal: "+"}, + {Type: "IDENT", Literal: "discount"}, + {Type: "AND", Literal: "AND"}, + {Type: "(", Literal: "("}, + {Type: "IDENT", Literal: "price"}, + {Type: "*", Literal: "*"}, + {Type: "FLOAT", Literal: "1.2"}, + {Type: ")", Literal: ")"}, + {Type: "MINUS", Literal: "-"}, + {Type: "IDENT", Literal: "fee"}, + } + + parser := NewParser() + defer parser.Release() + + tree, err := 
parser.Parse(tokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + stmt := tree.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected WHERE to be BetweenExpression, got %T", stmt.Where) + } + + // Verify main expression + ident, ok := betweenExpr.Expr.(*ast.Identifier) + if !ok { + t.Fatalf("expected Expr to be Identifier, got %T", betweenExpr.Expr) + } + if ident.Name != "amount" { + t.Errorf("expected Expr name 'amount', got '%s'", ident.Name) + } + + // Verify lower bound is addition: (price * 0.8) + discount + lowerBinary, ok := betweenExpr.Lower.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected lower bound to be BinaryExpression, got %T", betweenExpr.Lower) + } + if lowerBinary.Operator != "+" { + t.Errorf("expected lower bound operator '+', got '%s'", lowerBinary.Operator) + } + + // Verify upper bound is subtraction: (price * 1.2) - fee + upperBinary, ok := betweenExpr.Upper.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected upper bound to be BinaryExpression, got %T", betweenExpr.Upper) + } + if upperBinary.Operator != "-" { + t.Errorf("expected upper bound operator '-', got '%s'", upperBinary.Operator) + } +} + +// TestParser_BetweenWithNestedFunctionCalls tests BETWEEN with nested function calls +// Example: SELECT * FROM metrics WHERE score BETWEEN ROUND(AVG(baseline)) AND CEIL(MAX(threshold)) +func TestParser_BetweenWithNestedFunctionCalls(t *testing.T) { + tokens := []token.Token{ + {Type: "SELECT", Literal: "SELECT"}, + {Type: "*", Literal: "*"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "metrics"}, + {Type: "WHERE", Literal: "WHERE"}, + {Type: "IDENT", Literal: "score"}, + {Type: "BETWEEN", Literal: "BETWEEN"}, + {Type: "IDENT", Literal: "ROUND"}, + {Type: "(", Literal: "("}, + {Type: "IDENT", Literal: "AVG"}, + {Type: "(", Literal: "("}, + {Type: "IDENT", Literal: "baseline"}, + {Type: ")", Literal: 
")"}, + {Type: ")", Literal: ")"}, + {Type: "AND", Literal: "AND"}, + {Type: "IDENT", Literal: "CEIL"}, + {Type: "(", Literal: "("}, + {Type: "IDENT", Literal: "MAX"}, + {Type: "(", Literal: "("}, + {Type: "IDENT", Literal: "threshold"}, + {Type: ")", Literal: ")"}, + {Type: ")", Literal: ")"}, + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(tokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + stmt := tree.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected WHERE to be BetweenExpression, got %T", stmt.Where) + } + + // Verify lower bound is ROUND function with nested AVG + lowerFunc, ok := betweenExpr.Lower.(*ast.FunctionCall) + if !ok { + t.Fatalf("expected lower bound to be FunctionCall, got %T", betweenExpr.Lower) + } + if lowerFunc.Name != "ROUND" { + t.Errorf("expected lower function name 'ROUND', got '%s'", lowerFunc.Name) + } + if len(lowerFunc.Arguments) != 1 { + t.Errorf("expected 1 argument for ROUND, got %d", len(lowerFunc.Arguments)) + } + + // Verify nested AVG function + nestedAvg, ok := lowerFunc.Arguments[0].(*ast.FunctionCall) + if !ok { + t.Fatalf("expected nested function to be FunctionCall, got %T", lowerFunc.Arguments[0]) + } + if nestedAvg.Name != "AVG" { + t.Errorf("expected nested function name 'AVG', got '%s'", nestedAvg.Name) + } + + // Verify upper bound is CEIL function with nested MAX + upperFunc, ok := betweenExpr.Upper.(*ast.FunctionCall) + if !ok { + t.Fatalf("expected upper bound to be FunctionCall, got %T", betweenExpr.Upper) + } + if upperFunc.Name != "CEIL" { + t.Errorf("expected upper function name 'CEIL', got '%s'", upperFunc.Name) + } + + // Verify nested MAX function + nestedMax, ok := upperFunc.Arguments[0].(*ast.FunctionCall) + if !ok { + t.Fatalf("expected nested function to be FunctionCall, got %T", upperFunc.Arguments[0]) + } + if nestedMax.Name != "MAX" { + 
t.Errorf("expected nested function name 'MAX', got '%s'", nestedMax.Name) + } +} + +// TestParser_BetweenWithCastExpressions tests BETWEEN with CAST expressions +// Example: SELECT * FROM products WHERE price BETWEEN CAST(min_price AS DECIMAL) AND CAST(max_price AS DECIMAL) +func TestParser_BetweenWithCastExpressions(t *testing.T) { + tokens := []token.Token{ + {Type: "SELECT", Literal: "SELECT"}, + {Type: "*", Literal: "*"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "products"}, + {Type: "WHERE", Literal: "WHERE"}, + {Type: "IDENT", Literal: "price"}, + {Type: "BETWEEN", Literal: "BETWEEN"}, + {Type: "CAST", Literal: "CAST"}, + {Type: "(", Literal: "("}, + {Type: "IDENT", Literal: "min_price"}, + {Type: "AS", Literal: "AS"}, + {Type: "IDENT", Literal: "DECIMAL"}, + {Type: ")", Literal: ")"}, + {Type: "AND", Literal: "AND"}, + {Type: "CAST", Literal: "CAST"}, + {Type: "(", Literal: "("}, + {Type: "IDENT", Literal: "max_price"}, + {Type: "AS", Literal: "AS"}, + {Type: "IDENT", Literal: "DECIMAL"}, + {Type: ")", Literal: ")"}, + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(tokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + stmt := tree.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected WHERE to be BetweenExpression, got %T", stmt.Where) + } + + // Verify lower bound is a CAST expression + lowerCast, ok := betweenExpr.Lower.(*ast.CastExpression) + if !ok { + t.Fatalf("expected lower bound to be CastExpression, got %T", betweenExpr.Lower) + } + if lowerCast.Type != "DECIMAL" { + t.Errorf("expected lower cast type 'DECIMAL', got '%s'", lowerCast.Type) + } + + // Verify upper bound is a CAST expression + upperCast, ok := betweenExpr.Upper.(*ast.CastExpression) + if !ok { + t.Fatalf("expected upper bound to be CastExpression, got %T", betweenExpr.Upper) + } + if upperCast.Type != "DECIMAL" 
{ + t.Errorf("expected upper cast type 'DECIMAL', got '%s'", upperCast.Type) + } +} + +// TestParser_BetweenWithCaseExpressions tests BETWEEN with CASE expressions +// Example: SELECT * FROM orders WHERE total BETWEEN CASE WHEN discount THEN 100 ELSE 200 END AND 1000 +func TestParser_BetweenWithCaseExpressions(t *testing.T) { + tokens := []token.Token{ + {Type: "SELECT", Literal: "SELECT"}, + {Type: "*", Literal: "*"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "orders"}, + {Type: "WHERE", Literal: "WHERE"}, + {Type: "IDENT", Literal: "total"}, + {Type: "BETWEEN", Literal: "BETWEEN"}, + {Type: "CASE", Literal: "CASE"}, + {Type: "WHEN", Literal: "WHEN"}, + {Type: "IDENT", Literal: "discount"}, + {Type: "THEN", Literal: "THEN"}, + {Type: "INT", Literal: "100"}, + {Type: "ELSE", Literal: "ELSE"}, + {Type: "INT", Literal: "200"}, + {Type: "END", Literal: "END"}, + {Type: "AND", Literal: "AND"}, + {Type: "INT", Literal: "1000"}, + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(tokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + stmt := tree.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected WHERE to be BetweenExpression, got %T", stmt.Where) + } + + // Verify lower bound is a CASE expression + lowerCase, ok := betweenExpr.Lower.(*ast.CaseExpression) + if !ok { + t.Fatalf("expected lower bound to be CaseExpression, got %T", betweenExpr.Lower) + } + if len(lowerCase.WhenClauses) != 1 { + t.Errorf("expected 1 WHEN clause, got %d", len(lowerCase.WhenClauses)) + } + + // Verify upper bound is a literal + upperLit, ok := betweenExpr.Upper.(*ast.LiteralValue) + if !ok { + t.Fatalf("expected upper bound to be LiteralValue, got %T", betweenExpr.Upper) + } + if upperLit.Value != "1000" { + t.Errorf("expected upper bound value '1000', got '%v'", upperLit.Value) + } +} + +// 
TestParser_NotBetweenWithComplexExpressions tests NOT BETWEEN with complex expressions +// Example: SELECT * FROM products WHERE price NOT BETWEEN price * 0.5 AND price * 2 +func TestParser_NotBetweenWithComplexExpressions(t *testing.T) { + tokens := []token.Token{ + {Type: "SELECT", Literal: "SELECT"}, + {Type: "*", Literal: "*"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "products"}, + {Type: "WHERE", Literal: "WHERE"}, + {Type: "IDENT", Literal: "price"}, + {Type: "NOT", Literal: "NOT"}, + {Type: "BETWEEN", Literal: "BETWEEN"}, + {Type: "IDENT", Literal: "price"}, + {Type: "*", Literal: "*"}, + {Type: "FLOAT", Literal: "0.5"}, + {Type: "AND", Literal: "AND"}, + {Type: "IDENT", Literal: "price"}, + {Type: "*", Literal: "*"}, + {Type: "INT", Literal: "2"}, + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(tokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + stmt := tree.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected WHERE to be BetweenExpression, got %T", stmt.Where) + } + + // Verify NOT flag is set + if !betweenExpr.Not { + t.Error("expected NOT BETWEEN, but Not flag is false") + } + + // Verify lower bound is multiplication + lowerBinary, ok := betweenExpr.Lower.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected lower bound to be BinaryExpression, got %T", betweenExpr.Lower) + } + if lowerBinary.Operator != "*" { + t.Errorf("expected lower bound operator '*', got '%s'", lowerBinary.Operator) + } + + // Verify upper bound is multiplication + upperBinary, ok := betweenExpr.Upper.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected upper bound to be BinaryExpression, got %T", betweenExpr.Upper) + } + if upperBinary.Operator != "*" { + t.Errorf("expected upper bound operator '*', got '%s'", upperBinary.Operator) + } +} + +// TestParser_BetweenWithStringConcatenation tests 
BETWEEN with string concatenation +// Example: SELECT * FROM users WHERE full_name BETWEEN first_name || ' A' AND first_name || ' Z' +func TestParser_BetweenWithStringConcatenation(t *testing.T) { + tokens := []token.Token{ + {Type: "SELECT", Literal: "SELECT"}, + {Type: "*", Literal: "*"}, + {Type: "FROM", Literal: "FROM"}, + {Type: "IDENT", Literal: "users"}, + {Type: "WHERE", Literal: "WHERE"}, + {Type: "IDENT", Literal: "full_name"}, + {Type: "BETWEEN", Literal: "BETWEEN"}, + {Type: "IDENT", Literal: "first_name"}, + {Type: "STRING_CONCAT", Literal: "||"}, + {Type: "STRING", Literal: " A"}, + {Type: "AND", Literal: "AND"}, + {Type: "IDENT", Literal: "first_name"}, + {Type: "STRING_CONCAT", Literal: "||"}, + {Type: "STRING", Literal: " Z"}, + } + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(tokens) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer ast.ReleaseAST(tree) + + stmt := tree.Statements[0].(*ast.SelectStatement) + betweenExpr, ok := stmt.Where.(*ast.BetweenExpression) + if !ok { + t.Fatalf("expected WHERE to be BetweenExpression, got %T", stmt.Where) + } + + // Verify lower bound is string concatenation + lowerBinary, ok := betweenExpr.Lower.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected lower bound to be BinaryExpression, got %T", betweenExpr.Lower) + } + if lowerBinary.Operator != "||" { + t.Errorf("expected lower bound operator '||', got '%s'", lowerBinary.Operator) + } + + // Verify upper bound is string concatenation + upperBinary, ok := betweenExpr.Upper.(*ast.BinaryExpression) + if !ok { + t.Fatalf("expected upper bound to be BinaryExpression, got %T", betweenExpr.Upper) + } + if upperBinary.Operator != "||" { + t.Errorf("expected upper bound operator '||', got '%s'", upperBinary.Operator) + } +} diff --git a/pkg/sql/parser/expressions.go b/pkg/sql/parser/expressions.go index 5b63ee2..818be6e 100644 --- a/pkg/sql/parser/expressions.go +++ b/pkg/sql/parser/expressions.go @@ -657,6 
+657,12 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { } } + // Check for array subscript or slice syntax: identifier[...] + // This handles: arr[1], arr[1][2], arr[1:3], arr[2:], arr[:5] + if p.isType(models.TokenTypeLBracket) { + return p.parseArrayAccessExpression(ident) + } + return ident, nil } @@ -763,6 +769,13 @@ func (p *Parser) parsePrimaryExpression() (ast.Expression, error) { return nil, p.expectedError(")") } p.advance() // Consume ) + + // Check for array subscript or slice on parenthesized expression + // This handles: (expr)[1], (SELECT arr)[2:3] + if p.isType(models.TokenTypeLBracket) { + return p.parseArrayAccessExpression(expr) + } + return expr, nil } @@ -1173,3 +1186,132 @@ func (p *Parser) parseSubquery() (ast.Statement, error) { // // COUNT(*) -> regular aggregate function // ROW_NUMBER() OVER (ORDER BY id) -> window function with OVER clause + +// parseArrayAccessExpression parses array subscript and slice expressions. +// +// Supports: +// - Single subscript: arr[1] +// - Multi-dimensional subscript: arr[1][2][3] +// - Slice with both bounds: arr[1:3] +// - Slice from start: arr[:5] +// - Slice to end: arr[2:] +// - Full slice: arr[:] +// +// Examples: +// +// tags[1] -> ArraySubscriptExpression with single index +// matrix[2][3] -> Nested ArraySubscriptExpression (multi-dimensional) +// arr[1:3] -> ArraySliceExpression with start and end +// arr[2:] -> ArraySliceExpression with start only +// arr[:5] -> ArraySliceExpression with end only +// (SELECT arr)[1] -> Array access on subquery result +func (p *Parser) parseArrayAccessExpression(arrayExpr ast.Expression) (ast.Expression, error) { + // arrayExpr is the expression before the first '[' + // We need to parse one or more '[...]' subscripts/slices + + result := arrayExpr + + // Loop to handle chained subscripts: arr[1][2][3] + for p.isType(models.TokenTypeLBracket) { + p.advance() // Consume [ + + // Check for empty brackets [] - this is an error + if 
p.isType(models.TokenTypeRBracket) { + return nil, goerrors.InvalidSyntaxError( + "empty array subscript [] is not allowed", + p.currentLocation(), + "Use arr[index] or arr[start:end] syntax", + ) + } + + // Check for slice starting with colon: arr[:end] + if p.isType(models.TokenTypeColon) { + p.advance() // Consume : + + // Parse end expression (if not ']') + var endExpr ast.Expression + if !p.isType(models.TokenTypeRBracket) { + end, err := p.parseExpression() + if err != nil { + return nil, goerrors.InvalidSyntaxError( + fmt.Sprintf("failed to parse array slice end: %v", err), + p.currentLocation(), + "", + ) + } + endExpr = end + } + + // Expect closing bracket + if !p.isType(models.TokenTypeRBracket) { + return nil, p.expectedError("]") + } + p.advance() // Consume ] + + // Create ArraySliceExpression with no start + sliceExpr := ast.GetArraySliceExpression() + sliceExpr.Array = result + sliceExpr.Start = nil + sliceExpr.End = endExpr + result = sliceExpr + continue + } + + // Parse first expression (index or slice start) + firstExpr, err := p.parseExpression() + if err != nil { + return nil, goerrors.InvalidSyntaxError( + fmt.Sprintf("failed to parse array index/slice: %v", err), + p.currentLocation(), + "", + ) + } + + // Check if this is a slice (has colon) or subscript + if p.isType(models.TokenTypeColon) { + p.advance() // Consume : + + // Parse end expression (if not ']') + var endExpr ast.Expression + if !p.isType(models.TokenTypeRBracket) { + end, err := p.parseExpression() + if err != nil { + return nil, goerrors.InvalidSyntaxError( + fmt.Sprintf("failed to parse array slice end: %v", err), + p.currentLocation(), + "", + ) + } + endExpr = end + } + + // Expect closing bracket + if !p.isType(models.TokenTypeRBracket) { + return nil, p.expectedError("]") + } + p.advance() // Consume ] + + // Create ArraySliceExpression + sliceExpr := ast.GetArraySliceExpression() + sliceExpr.Array = result + sliceExpr.Start = firstExpr + sliceExpr.End = endExpr + 
result = sliceExpr + } else { + // This is a subscript, not a slice + // Expect closing bracket + if !p.isType(models.TokenTypeRBracket) { + return nil, p.expectedError("]") + } + p.advance() // Consume ] + + // Create ArraySubscriptExpression with single index + subscriptExpr := ast.GetArraySubscriptExpression() + subscriptExpr.Array = result + subscriptExpr.Indices = append(subscriptExpr.Indices, firstExpr) + result = subscriptExpr + } + } + + return result, nil +} diff --git a/pkg/sql/parser/token_converter.go b/pkg/sql/parser/token_converter.go index 828ac38..82e96e0 100644 --- a/pkg/sql/parser/token_converter.go +++ b/pkg/sql/parser/token_converter.go @@ -781,6 +781,12 @@ func buildTypeMapping() map[models.TokenType]token.Type { models.TokenTypeQuestion: "QUESTION", // ? key exists models.TokenTypeQuestionPipe: "QUESTION_PIPE", // ?| any keys exist models.TokenTypeQuestionAnd: "QUESTION_AND", // ?& all keys exist + + // PostgreSQL regex operators + models.TokenTypeTilde: "~", // ~ case-sensitive regex match + models.TokenTypeTildeAsterisk: "~*", // ~* case-insensitive regex match + models.TokenTypeExclamationMarkTilde: "!~", // !~ case-sensitive regex non-match + models.TokenTypeExclamationMarkTildeAsterisk: "!~*", // !~* case-insensitive regex non-match } }