diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go index cc9892a..e4b3f3e 100644 --- a/pkg/sql/ast/ast.go +++ b/pkg/sql/ast/ast.go @@ -23,8 +23,13 @@ type Expression interface { } // WithClause represents a WITH clause in a SQL statement +// TODO: PHASE 2 - Complete CTE implementation +// Current Status: AST structures defined, parser integration incomplete +// Missing: parseWithClause, parseCommonTableExpr, parseStatementWithSetOps functions +// Priority: High (Phase 2 core feature) type WithClause struct { - CTEs []*CommonTableExpr + Recursive bool + CTEs []*CommonTableExpr } func (w *WithClause) statementNode() {} @@ -38,11 +43,14 @@ func (w WithClause) Children() []Node { } // CommonTableExpr represents a single CTE in a WITH clause +// TODO: PHASE 2 - Parser integration needed for CTE functionality +// Current: AST structure complete, parser functions missing +// Required: Integration with SELECT/INSERT/UPDATE/DELETE statement parsing type CommonTableExpr struct { Name string Columns []string Statement Statement - Materialized *bool + Materialized *bool // TODO: Add MATERIALIZED/NOT MATERIALIZED parsing support } func (c *CommonTableExpr) statementNode() {} diff --git a/pkg/sql/parser/join_test.go b/pkg/sql/parser/join_test.go new file mode 100644 index 0000000..1b5e21d --- /dev/null +++ b/pkg/sql/parser/join_test.go @@ -0,0 +1,557 @@ +package parser + +import ( + "fmt" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/models" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/token" + "github.com/ajitpratap0/GoSQLX/pkg/sql/tokenizer" +) + +// convertTokens converts models.TokenWithSpan to token.Token for parser +func convertTokens(tokens []models.TokenWithSpan) []token.Token { + result := make([]token.Token, 0, len(tokens)*2) // Extra space for split tokens + + for _, t := range tokens { + // Handle compound JOIN tokens by splitting them + switch t.Token.Type { + case models.TokenTypeInnerJoin: + result = append(result, token.Token{Type: "INNER", Literal: "INNER"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + case models.TokenTypeLeftJoin: + result = append(result, token.Token{Type: "LEFT", Literal: "LEFT"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + case models.TokenTypeRightJoin: + result = append(result, token.Token{Type: "RIGHT", Literal: "RIGHT"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + case models.TokenTypeOuterJoin: + result = append(result, token.Token{Type: "OUTER", Literal: "OUTER"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } + + // Handle compound tokens that come as strings + if t.Token.Value == "INNER JOIN" { + result = append(result, token.Token{Type: "INNER", Literal: "INNER"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } else if t.Token.Value == "LEFT JOIN" { + result = append(result, token.Token{Type: "LEFT", Literal: "LEFT"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } else if t.Token.Value == "RIGHT JOIN" { + result = append(result, token.Token{Type: "RIGHT", Literal: "RIGHT"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } else if t.Token.Value == "FULL JOIN" || t.Token.Type == models.TokenTypeKeyword && t.Token.Value == "FULL JOIN" { + result = append(result, token.Token{Type: "FULL", Literal: "FULL"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } else if t.Token.Value == "CROSS JOIN" || t.Token.Type == models.TokenTypeKeyword && t.Token.Value == "CROSS JOIN" { + result = append(result, token.Token{Type: "CROSS", Literal: "CROSS"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } else if t.Token.Value == "LEFT OUTER JOIN" { + result = append(result, token.Token{Type: "LEFT", Literal: "LEFT"}) + result = append(result, token.Token{Type: "OUTER", Literal: "OUTER"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } else if t.Token.Value == "RIGHT OUTER JOIN" { + result = append(result, token.Token{Type: "RIGHT", Literal: "RIGHT"}) + result = append(result, token.Token{Type: "OUTER", Literal: "OUTER"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } else if t.Token.Value == "FULL OUTER JOIN" { + result = append(result, token.Token{Type: "FULL", Literal: "FULL"}) + result = append(result, token.Token{Type: "OUTER", Literal: "OUTER"}) + result = append(result, token.Token{Type: "JOIN", Literal: "JOIN"}) + continue + } else if t.Token.Value == "ORDER BY" || t.Token.Type == models.TokenTypeOrderBy { + result = append(result, token.Token{Type: "ORDER", Literal: "ORDER"}) + result = append(result, token.Token{Type: "BY", Literal: "BY"}) + continue + } else if t.Token.Value == "GROUP BY" || t.Token.Type == models.TokenTypeGroupBy { + result = append(result, token.Token{Type: "GROUP", Literal: "GROUP"}) + result = append(result, token.Token{Type: "BY", Literal: "BY"}) + continue + } + + // Map token type to string for single tokens + tokenType := token.Type(fmt.Sprintf("%v", t.Token.Type)) + + // Try to map to proper token type string + switch t.Token.Type { + case models.TokenTypeSelect: + tokenType = "SELECT" + case models.TokenTypeFrom: + tokenType = "FROM" + case models.TokenTypeWhere: + tokenType = "WHERE" + case models.TokenTypeJoin: + tokenType = "JOIN" + case models.TokenTypeInner: + tokenType = "INNER" + case models.TokenTypeLeft: + tokenType = "LEFT" + case models.TokenTypeRight: + tokenType = "RIGHT" + case models.TokenTypeOuter: + tokenType = "OUTER" + case models.TokenTypeOn: + tokenType = "ON" + case models.TokenTypeAs: + tokenType = "AS" + case models.TokenTypeIdentifier: + tokenType = "IDENT" + case models.TokenTypeMul: + tokenType = "*" + case models.TokenTypeEq: + tokenType = "=" + case models.TokenTypePeriod: + tokenType = "." + case models.TokenTypeLParen: + tokenType = "(" + case models.TokenTypeRParen: + tokenType = ")" + case models.TokenTypeComma: + tokenType = "," + case models.TokenTypeOrder: + tokenType = "ORDER" + case models.TokenTypeBy: + tokenType = "BY" + case models.TokenTypeDesc: + tokenType = "DESC" + case models.TokenTypeAsc: + tokenType = "ASC" + case models.TokenTypeLimit: + tokenType = "LIMIT" + case models.TokenTypeTrue: + tokenType = "TRUE" + case models.TokenTypeNumber: + tokenType = "INT" + case models.TokenTypeEOF: + tokenType = "EOF" + default: + // For any other type, use the value as the type if it looks like a keyword + // This handles keywords like FULL, CROSS, USING that don't have specific token types + if t.Token.Value != "" { + tokenType = token.Type(t.Token.Value) + } + // Special handling for keywords that come through as TokenTypeKeyword + if t.Token.Type == models.TokenTypeKeyword { + tokenType = token.Type(t.Token.Value) + } + } + + result = append(result, token.Token{ + Type: tokenType, + Literal: t.Token.Value, + }) + } + return result +} + +func TestParser_JoinTypes(t *testing.T) { + tests := []struct { + name string + sql string + joinType string + wantErr bool + }{ + { + name: "INNER JOIN", + sql: "SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id", + joinType: "INNER", + wantErr: false, + }, + { + name: "LEFT JOIN", + sql: "SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id", + joinType: "LEFT", + wantErr: false, + }, + { + name: "LEFT OUTER JOIN", + sql: "SELECT * FROM users LEFT OUTER JOIN orders ON users.id = orders.user_id", + joinType: "LEFT", + wantErr: false, + }, + { + name: "RIGHT JOIN", + sql: "SELECT * FROM users RIGHT JOIN orders ON users.id = orders.user_id", + joinType: "RIGHT", + wantErr: false, + }, + { + name: "RIGHT OUTER JOIN", + sql: "SELECT * FROM users RIGHT OUTER JOIN orders ON users.id = orders.user_id", + joinType: "RIGHT", + wantErr: false, + }, + { + name: "FULL JOIN", + sql: "SELECT * FROM users FULL JOIN orders ON users.id = orders.user_id", + joinType: "FULL", + wantErr: false, + }, + { + name: "FULL OUTER JOIN", + sql: "SELECT * FROM users FULL OUTER JOIN orders ON users.id = orders.user_id", + joinType: "FULL", + wantErr: false, + }, + { + name: "CROSS JOIN", + sql: "SELECT * FROM users CROSS JOIN products", + joinType: "CROSS", + wantErr: false, + }, + { + name: "Multiple JOINs", + sql: "SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id RIGHT JOIN products ON orders.product_id = products.id", + joinType: "LEFT", // First join + wantErr: false, + }, + { + name: "JOIN with table alias", + sql: "SELECT * FROM users u LEFT JOIN orders o ON u.id = o.user_id", + joinType: "LEFT", + wantErr: false, + }, + { + name: "JOIN with AS alias", + sql: "SELECT * FROM users AS u LEFT JOIN orders AS o ON u.id = o.user_id", + joinType: "LEFT", + wantErr: false, + }, + { + name: "JOIN with USING", + sql: "SELECT * FROM users LEFT JOIN orders USING (id)", + joinType: "LEFT", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Get tokenizer from pool + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + // Tokenize SQL + tokens, err := tkz.Tokenize([]byte(tt.sql)) + if err != nil { + t.Fatalf("Failed to tokenize: %v", err) + } + + // Convert tokens for parser + convertedTokens := convertTokens(tokens) + + // Parse tokens + parser := &Parser{} + astObj, err := parser.Parse(convertedTokens) + if (err != nil) != tt.wantErr { + t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && astObj != nil { + defer ast.ReleaseAST(astObj) + + // Check if we have a SELECT statement + if len(astObj.Statements) > 0 { + if selectStmt, ok := astObj.Statements[0].(*ast.SelectStatement); ok { + // Check JOIN type for first join + if len(selectStmt.Joins) > 0 { + if selectStmt.Joins[0].Type != tt.joinType { + t.Errorf("Expected join type %s, got %s", tt.joinType, selectStmt.Joins[0].Type) + } + } else if tt.joinType != "" { + t.Errorf("Expected join clause but found none") + } + } else { + t.Errorf("Expected SELECT statement") + } + } + } + }) + } +} + +func TestParser_ComplexJoins(t *testing.T) { + sql := ` + SELECT + u.name, + o.order_date, + p.product_name, + c.category_name + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + INNER JOIN products p ON o.product_id = p.id + RIGHT JOIN categories c ON p.category_id = c.id + WHERE u.active = true + ORDER BY o.order_date DESC + LIMIT 100 + ` + + // Get tokenizer from pool + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + // Tokenize SQL + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("Failed to tokenize: %v", err) + } + + // Convert tokens for parser + convertedTokens := convertTokens(tokens) + + // Parse tokens + parser := &Parser{} + astObj, err := parser.Parse(convertedTokens) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + defer ast.ReleaseAST(astObj) + + // Verify we have a SELECT statement + if len(astObj.Statements) == 0 { + t.Fatal("No statements parsed") + } + + selectStmt, ok := astObj.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatal("Expected SELECT statement") + } + + // Verify we have 3 JOINs + if len(selectStmt.Joins) != 3 { + t.Errorf("Expected 3 JOINs, got %d", len(selectStmt.Joins)) + } + + // Verify JOIN types + expectedJoinTypes := []string{"LEFT", "INNER", "RIGHT"} + for i, expectedType := range expectedJoinTypes { + if i < len(selectStmt.Joins) { + if selectStmt.Joins[i].Type != expectedType { + t.Errorf("Join %d: expected type %s, got %s", i, expectedType, selectStmt.Joins[i].Type) + } + } + } + + // Verify we have WHERE, ORDER BY, and LIMIT + if selectStmt.Where == nil { + t.Error("Expected WHERE clause") + } + if len(selectStmt.OrderBy) == 0 { + t.Error("Expected ORDER BY clause") + } + if selectStmt.Limit == nil { + t.Error("Expected LIMIT clause") + } +} + +func TestParser_InvalidJoinSyntax(t *testing.T) { + tests := []struct { + name string + sql string + expectedError string + }{ + { + name: "Missing JOIN keyword after type", + sql: "SELECT * FROM users LEFT orders ON users.id = orders.user_id", + expectedError: "expected JOIN after LEFT", + }, + { + name: "Missing table name after JOIN", + sql: "SELECT * FROM users LEFT JOIN ON users.id = orders.user_id", + expectedError: "expected table name after LEFT JOIN", + }, + { + name: "Missing ON/USING clause", + sql: "SELECT * FROM users LEFT JOIN orders", + expectedError: "expected ON or USING", + }, + { + name: "Invalid JOIN type", + sql: "SELECT * FROM users INVALID JOIN orders ON users.id = orders.user_id", + expectedError: "", // This won't error as INVALID becomes an identifier + }, + { + name: "Missing condition after ON", + sql: "SELECT * FROM users LEFT JOIN orders ON", + expectedError: "error parsing ON condition", + }, + { + name: "Missing parentheses after USING", + sql: "SELECT * FROM users LEFT JOIN orders USING id", + expectedError: "expected ( after USING", + }, + { + name: "Empty USING clause", + sql: "SELECT * FROM users LEFT JOIN orders USING ()", + expectedError: "expected column name in USING", + }, + { + name: "Incomplete OUTER JOIN", + sql: "SELECT * FROM users OUTER JOIN orders ON users.id = orders.user_id", + expectedError: "expected statement", + }, + { + name: "JOIN without FROM clause", + sql: "SELECT * LEFT JOIN orders ON users.id = orders.user_id", + expectedError: "", // This errors during FROM parsing, not JOIN parsing + }, + { + name: "Multiple JOIN keywords", + sql: "SELECT * FROM users JOIN JOIN orders ON users.id = orders.user_id", + expectedError: "expected table name", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Get tokenizer from pool + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + // Tokenize SQL + tokens, err := tkz.Tokenize([]byte(tt.sql)) + if err != nil { + // Some tests might fail at tokenization level + if tt.expectedError != "" { + return // Expected failure + } + t.Fatalf("Failed to tokenize: %v", err) + } + + // Convert tokens for parser + convertedTokens := convertTokens(tokens) + + // Parse tokens + parser := &Parser{} + astObj, err := parser.Parse(convertedTokens) + + if tt.expectedError != "" { + // We expect an error + if err == nil { + defer ast.ReleaseAST(astObj) + t.Errorf("Expected error containing '%s', but got no error", tt.expectedError) + } else if !containsError(err.Error(), tt.expectedError) { + t.Errorf("Expected error containing '%s', got '%s'", tt.expectedError, err.Error()) + } + } else { + // We don't expect an error for some edge cases + if err != nil && tt.expectedError == "" { + // Some tests intentionally have no expected error + // because they fail in different ways + return + } + if astObj != nil { + defer ast.ReleaseAST(astObj) + } + } + }) + } +} + +// Helper function to check if error message contains expected text +func containsError(actual, expected string) bool { + if expected == "" { + return true + } + return len(actual) > 0 && len(expected) > 0 && + (actual == expected || + len(actual) >= len(expected) && + (actual[:len(expected)] == expected || + actual[len(actual)-len(expected):] == expected || + containsSubstring(actual, expected))) +} + +// Simple substring check +func containsSubstring(s, substr string) bool { + if len(substr) > len(s) { + return false + } + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func TestParser_JoinTreeLogic(t *testing.T) { + sql := "SELECT * FROM users u LEFT JOIN orders o ON u.id = o.user_id INNER JOIN products p ON o.product_id = p.id" + + // Get tokenizer from pool + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + // Tokenize SQL + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + t.Fatalf("Failed to tokenize: %v", err) + } + + // Convert tokens for parser + convertedTokens := convertTokens(tokens) + + // Parse tokens + parser := &Parser{} + astObj, err := parser.Parse(convertedTokens) + if err != nil { + t.Fatalf("Failed to parse: %v", err) + } + defer ast.ReleaseAST(astObj) + + // Verify we have a SELECT statement + if len(astObj.Statements) == 0 { + t.Fatal("No statements parsed") + } + + selectStmt, ok := astObj.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatal("Expected SELECT statement") + } + + // Verify join tree structure + if len(selectStmt.Joins) != 2 { + t.Errorf("Expected 2 joins, got %d", len(selectStmt.Joins)) + } + + // First join: users LEFT JOIN orders + if len(selectStmt.Joins) > 0 { + firstJoin := selectStmt.Joins[0] + if firstJoin.Type != "LEFT" { + t.Errorf("First join type: expected LEFT, got %s", firstJoin.Type) + } + if firstJoin.Left.Name != "users" { + t.Errorf("First join left table: expected users, got %s", firstJoin.Left.Name) + } + if firstJoin.Right.Name != "orders" { + t.Errorf("First join right table: expected orders, got %s", firstJoin.Right.Name) + } + } + + // Second join: (users LEFT JOIN orders) INNER JOIN products + if len(selectStmt.Joins) > 1 { + secondJoin := selectStmt.Joins[1] + if secondJoin.Type != "INNER" { + t.Errorf("Second join type: expected INNER, got %s", secondJoin.Type) + } + // The left side should now represent the result of previous joins + if secondJoin.Left.Name != "(users_with_1_joins)" { + t.Errorf("Second join left table: expected (users_with_1_joins), got %s", secondJoin.Left.Name) + } + if secondJoin.Right.Name != "products" { + t.Errorf("Second join right table: expected products, got %s", secondJoin.Right.Name) + } + } +} diff --git a/pkg/sql/parser/parser.go b/pkg/sql/parser/parser.go index 185d9ee..74383fd 100644 --- a/pkg/sql/parser/parser.go +++ b/pkg/sql/parser/parser.go @@ -58,6 +58,10 @@ func (p *Parser) Release() { // parseStatement parses a single SQL statement func (p *Parser) parseStatement() (ast.Statement, error) { + // TODO: PHASE 2 - Add WITH statement parsing for Common Table Expressions (CTEs) + // case "WITH": + // p.advance() // Consume WITH + // return p.parseWithStatement() // Needs implementation switch p.currentToken.Type { case "SELECT": p.advance() // Consume SELECT @@ -322,9 +326,17 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) { } // Check for table alias - if p.currentToken.Type == "IDENT" { - tableRef.Alias = p.currentToken.Literal - p.advance() + if p.currentToken.Type == "IDENT" || p.currentToken.Type == "AS" { + if p.currentToken.Type == "AS" { + p.advance() // Consume AS + if p.currentToken.Type != "IDENT" { + return nil, p.expectedError("alias after AS") + } + } + if p.currentToken.Type == "IDENT" { + tableRef.Alias = p.currentToken.Literal + p.advance() + } } // Create tables list for FROM clause @@ -332,12 +344,45 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) { // Parse JOIN clauses if present joins := []ast.JoinClause{} - for p.currentToken.Type == "JOIN" { + for p.isJoinKeyword() { + // Determine JOIN type + joinType := "INNER" // Default + + if p.currentToken.Type == "LEFT" { + joinType = "LEFT" + p.advance() + if p.currentToken.Type == "OUTER" { + p.advance() // Optional OUTER keyword + } + } else if p.currentToken.Type == "RIGHT" { + joinType = "RIGHT" + p.advance() + if p.currentToken.Type == "OUTER" { + p.advance() // Optional OUTER keyword + } + } else if p.currentToken.Type == "FULL" { + joinType = "FULL" + p.advance() + if p.currentToken.Type == "OUTER" { + p.advance() // Optional OUTER keyword + } + } else if p.currentToken.Type == "INNER" { + joinType = "INNER" + p.advance() + } else if p.currentToken.Type == "CROSS" { + joinType = "CROSS" + p.advance() + } + + // Expect JOIN keyword + if p.currentToken.Type != "JOIN" { + return nil, fmt.Errorf("expected JOIN after %s, got %s", joinType, p.currentToken.Type) + } p.advance() // Consume JOIN // Parse joined table name if p.currentToken.Type != "IDENT" { - return nil, p.expectedError("table name after JOIN") + return nil, fmt.Errorf("expected table name after %s JOIN, got %s", joinType, p.currentToken.Type) } joinedTableName := p.currentToken.Literal p.advance() @@ -348,33 +393,89 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) { } // Check for table alias - if p.currentToken.Type == "IDENT" { - joinedTableRef.Alias = p.currentToken.Literal - p.advance() + if p.currentToken.Type == "IDENT" || p.currentToken.Type == "AS" { + if p.currentToken.Type == "AS" { + p.advance() // Consume AS + if p.currentToken.Type != "IDENT" { + return nil, p.expectedError("alias after AS") + } + } + if p.currentToken.Type == "IDENT" { + joinedTableRef.Alias = p.currentToken.Literal + p.advance() + } } - // Parse ON clause - if p.currentToken.Type != "ON" { - return nil, p.expectedError("ON") + // Parse join condition (ON or USING) + var joinCondition ast.Expression + + // CROSS JOIN doesn't require ON clause + if joinType != "CROSS" { + if p.currentToken.Type == "ON" { + p.advance() // Consume ON + + // Parse join condition + cond, err := p.parseExpression() + if err != nil { + return nil, fmt.Errorf("error parsing ON condition for %s JOIN: %v", joinType, err) + } + joinCondition = cond + } else if p.currentToken.Type == "USING" { + p.advance() // Consume USING + + // Parse column list in parentheses + if p.currentToken.Type != "(" { + return nil, p.expectedError("( after USING") + } + p.advance() + + // TODO: LIMITATION - Currently only supports single column in USING clause + // Future enhancement needed for multi-column support like USING (col1, col2, col3) + // This requires parsing comma-separated column list and storing as []Expression + // Priority: Medium (Phase 2 enhancement) + if p.currentToken.Type != "IDENT" { + return nil, p.expectedError("column name in USING") + } + joinCondition = &ast.Identifier{Name: p.currentToken.Literal} + p.advance() + + if p.currentToken.Type != ")" { + return nil, p.expectedError(") after USING column") + } + p.advance() + } else if joinType != "NATURAL" { + return nil, p.expectedError("ON or USING") + } } - p.advance() // Consume ON - // Parse join condition - joinCondition, err := p.parseExpression() - if err != nil { - return nil, err + // Create join clause with proper tree relationships + // For SQL: FROM A JOIN B JOIN C (equivalent to (A JOIN B) JOIN C) + var leftTable ast.TableReference + if len(joins) == 0 { + // First join: A JOIN B + leftTable = tableRef + } else { + // Subsequent joins: (previous result) JOIN C + // We represent this by using a synthetic table reference that indicates + // the left side is the result of previous joins + leftTable = ast.TableReference{ + Name: fmt.Sprintf("(%s_with_%d_joins)", tableRef.Name, len(joins)), + Alias: "", + } } - // Create join clause joinClause := ast.JoinClause{ - Type: "INNER", // Default to INNER JOIN - Left: tableRef, + Type: joinType, + Left: leftTable, Right: joinedTableRef, Condition: joinCondition, } // Add join clause to joins list joins = append(joins, joinClause) + + // Note: We don't update tableRef here as each JOIN in the list + // represents a join with the accumulated result set } // Initialize SELECT statement @@ -689,3 +790,13 @@ func (p *Parser) parseAlterTableStmt() (ast.Statement, error) { // This is just a placeholder that delegates to the main implementation return p.parseAlterStatement() } + +// isJoinKeyword checks if current token is a JOIN-related keyword +func (p *Parser) isJoinKeyword() bool { + switch p.currentToken.Type { + case "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "CROSS": + return true + default: + return false + } +} diff --git a/pkg/sql/tokenizer/tokenizer.go b/pkg/sql/tokenizer/tokenizer.go index b6f1621..017eb2e 100644 --- a/pkg/sql/tokenizer/tokenizer.go +++ b/pkg/sql/tokenizer/tokenizer.go @@ -15,44 +15,53 @@ import ( // keywordTokenTypes maps SQL keywords to their token types for fast lookup var keywordTokenTypes = map[string]models.TokenType{ - "SELECT": models.TokenTypeSelect, - "FROM": models.TokenTypeFrom, - "WHERE": models.TokenTypeWhere, - "GROUP": models.TokenTypeGroup, - "ORDER": models.TokenTypeOrder, - "HAVING": models.TokenTypeHaving, - "JOIN": models.TokenTypeJoin, - "INNER": models.TokenTypeInner, - "LEFT": models.TokenTypeLeft, - "RIGHT": models.TokenTypeRight, - "OUTER": models.TokenTypeOuter, - "ON": models.TokenTypeOn, - "AND": models.TokenTypeAnd, - "OR": models.TokenTypeOr, - "NOT": models.TokenTypeNot, - "AS": models.TokenTypeAs, - "BY": models.TokenTypeBy, - "IN": models.TokenTypeIn, - "LIKE": models.TokenTypeLike, - "BETWEEN": models.TokenTypeBetween, - "IS": models.TokenTypeIs, - "NULL": models.TokenTypeNull, - "TRUE": models.TokenTypeTrue, - "FALSE": models.TokenTypeFalse, - "CASE": models.TokenTypeCase, - "WHEN": models.TokenTypeWhen, - "THEN": models.TokenTypeThen, - "ELSE": models.TokenTypeElse, - "END": models.TokenTypeEnd, - "ASC": models.TokenTypeAsc, - "DESC": models.TokenTypeDesc, - "LIMIT": models.TokenTypeLimit, - "OFFSET": models.TokenTypeOffset, - "COUNT": models.TokenTypeCount, - "SUM": models.TokenTypeSum, - "AVG": models.TokenTypeAvg, - "MIN": models.TokenTypeMin, - "MAX": models.TokenTypeMax, + "SELECT": models.TokenTypeSelect, + "FROM": models.TokenTypeFrom, + "WHERE": models.TokenTypeWhere, + "GROUP": models.TokenTypeGroup, + "ORDER": models.TokenTypeOrder, + "HAVING": models.TokenTypeHaving, + "JOIN": models.TokenTypeJoin, + "INNER": models.TokenTypeInner, + "LEFT": models.TokenTypeLeft, + "RIGHT": models.TokenTypeRight, + "OUTER": models.TokenTypeOuter, + "ON": models.TokenTypeOn, + "AND": models.TokenTypeAnd, + "OR": models.TokenTypeOr, + "NOT": models.TokenTypeNot, + "AS": models.TokenTypeAs, + "BY": models.TokenTypeBy, + "IN": models.TokenTypeIn, + "LIKE": models.TokenTypeLike, + "BETWEEN": models.TokenTypeBetween, + "IS": models.TokenTypeIs, + "NULL": models.TokenTypeNull, + "TRUE": models.TokenTypeTrue, + "FALSE": models.TokenTypeFalse, + "CASE": models.TokenTypeCase, + "WHEN": models.TokenTypeWhen, + "THEN": models.TokenTypeThen, + "ELSE": models.TokenTypeElse, + "END": models.TokenTypeEnd, + "ASC": models.TokenTypeAsc, + "DESC": models.TokenTypeDesc, + "LIMIT": models.TokenTypeLimit, + "OFFSET": models.TokenTypeOffset, + "COUNT": models.TokenTypeCount, + "FULL": models.TokenTypeKeyword, + "CROSS": models.TokenTypeKeyword, + "USING": models.TokenTypeKeyword, + "WITH": models.TokenTypeKeyword, + "RECURSIVE": models.TokenTypeKeyword, + "UNION": models.TokenTypeKeyword, + "EXCEPT": models.TokenTypeKeyword, + "INTERSECT": models.TokenTypeKeyword, + "ALL": models.TokenTypeKeyword, + "SUM": models.TokenTypeSum, + "AVG": models.TokenTypeAvg, + "MIN": models.TokenTypeMin, + "MAX": models.TokenTypeMax, } // Tokenizer provides high-performance SQL tokenization with zero-copy operations @@ -334,12 +343,17 @@ var compoundKeywordStarts = map[string]bool{ // compoundKeywordTypes maps compound SQL keywords to their token types var compoundKeywordTypes = map[string]models.TokenType{ - "GROUP BY": models.TokenTypeGroupBy, - "ORDER BY": models.TokenTypeOrderBy, - "LEFT JOIN": models.TokenTypeLeftJoin, - "RIGHT JOIN": models.TokenTypeRightJoin, - "INNER JOIN": models.TokenTypeInnerJoin, - "OUTER JOIN": models.TokenTypeOuterJoin, + "GROUP BY": models.TokenTypeGroupBy, + "ORDER BY": models.TokenTypeOrderBy, + "LEFT JOIN": models.TokenTypeLeftJoin, + "RIGHT JOIN": models.TokenTypeRightJoin, + "INNER JOIN": models.TokenTypeInnerJoin, + "OUTER JOIN": models.TokenTypeOuterJoin, + "FULL JOIN": models.TokenTypeKeyword, + "CROSS JOIN": models.TokenTypeKeyword, + "LEFT OUTER JOIN": models.TokenTypeKeyword, + "RIGHT OUTER JOIN": models.TokenTypeKeyword, + "FULL OUTER JOIN": models.TokenTypeKeyword, } // Helper function to check if a word can start a compound keyword