Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions cmd/gosqlx/cmd/analyze.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"bytes"
"fmt"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -86,17 +87,32 @@ func analyzeRun(cmd *cobra.Command, args []string) error {
Verbose: verbose,
})

// Use a buffer to capture output when writing to file
var outputBuf bytes.Buffer
var outWriter = cmd.OutOrStdout()
if outputFile != "" {
outWriter = &outputBuf
}

// Create analyzer with injectable output writers
analyzer := NewAnalyzer(cmd.OutOrStdout(), cmd.ErrOrStderr(), opts)
analyzer := NewAnalyzer(outWriter, cmd.ErrOrStderr(), opts)

// Run analysis
result, err := analyzer.Analyze(args[0])
if err != nil {
return err
}

// Display the report
return analyzer.DisplayReport(result.Report)
// Display the report (writes to outWriter)
if err := analyzer.DisplayReport(result.Report); err != nil {
return err
}

// Write to file if specified
if outputFile != "" {
return WriteOutput(outputBuf.Bytes(), outputFile, cmd.OutOrStdout())
}
return nil
}

// analyzeFromStdin handles analysis from stdin input
Expand Down Expand Up @@ -139,17 +155,32 @@ func analyzeFromStdin(cmd *cobra.Command) error {
Verbose: verbose,
})

// Use a buffer to capture output when writing to file
var outputBuf bytes.Buffer
var outWriter = cmd.OutOrStdout()
if outputFile != "" {
outWriter = &outputBuf
}

// Create analyzer
analyzer := NewAnalyzer(cmd.OutOrStdout(), cmd.ErrOrStderr(), opts)
analyzer := NewAnalyzer(outWriter, cmd.ErrOrStderr(), opts)

// Analyze the stdin content (Analyze accepts string input directly)
result, err := analyzer.Analyze(string(content))
if err != nil {
return err
}

// Display the report
return analyzer.DisplayReport(result.Report)
// Display the report (writes to outWriter)
if err := analyzer.DisplayReport(result.Report); err != nil {
return err
}

// Write to file if specified
if outputFile != "" {
return WriteOutput(outputBuf.Bytes(), outputFile, cmd.OutOrStdout())
}
return nil
}

func init() {
Expand Down
54 changes: 44 additions & 10 deletions cmd/gosqlx/cmd/lint.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package cmd

import (
"bytes"
"fmt"
"io"
"os"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -79,13 +81,20 @@ func lintRun(cmd *cobra.Command, args []string) error {
result = l.LintFiles(args)
}

// Use a buffer to capture output when writing to file
var outputBuf bytes.Buffer
var outWriter io.Writer = cmd.OutOrStdout()
if outputFile != "" {
outWriter = &outputBuf
}

// Display results
output := linter.FormatResult(result)
fmt.Fprint(cmd.OutOrStdout(), output)
fmt.Fprint(outWriter, output)

// Apply auto-fix if requested
if lintAutoFix && result.TotalViolations > 0 {
fmt.Fprintln(cmd.OutOrStdout(), "\nApplying auto-fixes...")
fmt.Fprintln(outWriter, "\nApplying auto-fixes...")
fixCount := 0

for _, fileResult := range result.Files {
Expand Down Expand Up @@ -134,11 +143,18 @@ func lintRun(cmd *cobra.Command, args []string) error {
continue
}
fixCount++
fmt.Fprintf(cmd.OutOrStdout(), "Fixed: %s\n", fileResult.Filename)
fmt.Fprintf(outWriter, "Fixed: %s\n", fileResult.Filename)
}
}

fmt.Fprintf(cmd.OutOrStdout(), "\nAuto-fixed %d file(s)\n", fixCount)
fmt.Fprintf(outWriter, "\nAuto-fixed %d file(s)\n", fixCount)
}

// Write to file if specified
if outputFile != "" {
if err := WriteOutput(outputBuf.Bytes(), outputFile, cmd.OutOrStdout()); err != nil {
return err
}
}

// Exit with error code if there were violations
Expand Down Expand Up @@ -181,22 +197,33 @@ func lintFromStdin(cmd *cobra.Command) error {
// Lint the content
result := l.LintString(string(content), "stdin")

// Use a buffer to capture output when writing to file
var outputBuf bytes.Buffer
var outWriter io.Writer = cmd.OutOrStdout()
if outputFile != "" {
outWriter = &outputBuf
}

// Display results
fmt.Fprintf(cmd.OutOrStdout(), "Linting stdin input:\n\n")
fmt.Fprintf(outWriter, "Linting stdin input:\n\n")

if len(result.Violations) == 0 {
fmt.Fprintln(cmd.OutOrStdout(), "No violations found.")
fmt.Fprintln(outWriter, "No violations found.")
// Write to file if specified
if outputFile != "" {
return WriteOutput(outputBuf.Bytes(), outputFile, cmd.OutOrStdout())
}
return nil
}

fmt.Fprintf(cmd.OutOrStdout(), "Found %d violation(s):\n\n", len(result.Violations))
fmt.Fprintf(outWriter, "Found %d violation(s):\n\n", len(result.Violations))
for i, violation := range result.Violations {
fmt.Fprintf(cmd.OutOrStdout(), "%d. %s\n", i+1, linter.FormatViolation(violation))
fmt.Fprintf(outWriter, "%d. %s\n", i+1, linter.FormatViolation(violation))
}

// Apply auto-fix if requested
if lintAutoFix {
fmt.Fprintln(cmd.OutOrStdout(), "\nAuto-fixed output:")
fmt.Fprintln(outWriter, "\nAuto-fixed output:")
fixed := string(content)
for _, rule := range l.Rules() {
if rule.CanAutoFix() {
Expand All @@ -206,7 +233,14 @@ func lintFromStdin(cmd *cobra.Command) error {
}
}
}
fmt.Fprintln(cmd.OutOrStdout(), fixed)
fmt.Fprintln(outWriter, fixed)
}

// Write to file if specified
if outputFile != "" {
if err := WriteOutput(outputBuf.Bytes(), outputFile, cmd.OutOrStdout()); err != nil {
return err
}
}

// Exit with error code if there were violations
Expand Down
108 changes: 106 additions & 2 deletions cmd/gosqlx/cmd/sql_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,16 @@ func (f *SQLFormatter) formatExpression(expr ast.Expression) error {
if err := f.formatExpression(e.Expr); err != nil {
return err
}
case *ast.AliasedExpression:
// Handle expr AS alias
if err := f.formatExpression(e.Expr); err != nil {
return err
}
f.builder.WriteString(" ")
f.writeKeyword("AS")
f.builder.WriteString(" ")
// Quote alias if it contains special characters or is a reserved keyword
f.formatIdentifier(e.Alias)
default:
// Fallback for unsupported expressions
f.builder.WriteString(expr.TokenLiteral())
Expand Down Expand Up @@ -769,8 +779,20 @@ func (f *SQLFormatter) formatTableReferences(tables []ast.TableReference) {
}

func (f *SQLFormatter) formatTableReference(table *ast.TableReference) {
f.builder.WriteString(table.Name)
if table.Subquery != nil {
// Format derived table (subquery)
f.builder.WriteString("(")
if err := f.formatSelect(table.Subquery); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to format derived table: %v\n", err)
}
f.builder.WriteString(")")
} else {
// Format regular table name
f.builder.WriteString(table.Name)
}
if table.Alias != "" {
f.builder.WriteString(" ")
f.writeKeyword("AS")
f.builder.WriteString(" " + table.Alias)
}
}
Expand All @@ -786,7 +808,46 @@ func (f *SQLFormatter) formatUpdateExpression(update *ast.UpdateExpression) erro
func (f *SQLFormatter) formatColumnDef(col *ast.ColumnDef) {
f.builder.WriteString(col.Name + " " + col.Type)
for _, constraint := range col.Constraints {
f.builder.WriteString(" " + constraint.Type)
f.builder.WriteString(" ")
f.writeKeyword(constraint.Type)
// Format DEFAULT value if present
if constraint.Type == "DEFAULT" && constraint.Default != nil {
f.builder.WriteString(" ")
if err := f.formatExpression(constraint.Default); err != nil {
// Fallback to token literal on error
f.builder.WriteString(constraint.Default.TokenLiteral())
}
}
// Format CHECK expression if present
if constraint.Type == "CHECK" && constraint.Check != nil {
f.builder.WriteString(" (")
if err := f.formatExpression(constraint.Check); err != nil {
f.builder.WriteString(constraint.Check.TokenLiteral())
}
f.builder.WriteString(")")
}
// Format REFERENCES if present
if constraint.Type == "REFERENCES" && constraint.References != nil {
f.builder.WriteString(" ")
f.builder.WriteString(constraint.References.Table)
if len(constraint.References.Columns) > 0 {
f.builder.WriteString("(")
f.builder.WriteString(strings.Join(constraint.References.Columns, ", "))
f.builder.WriteString(")")
}
if constraint.References.OnDelete != "" {
f.builder.WriteString(" ")
f.writeKeyword("ON DELETE")
f.builder.WriteString(" ")
f.writeKeyword(constraint.References.OnDelete)
}
if constraint.References.OnUpdate != "" {
f.builder.WriteString(" ")
f.writeKeyword("ON UPDATE")
f.builder.WriteString(" ")
f.writeKeyword(constraint.References.OnUpdate)
}
}
}
}

Expand All @@ -810,6 +871,49 @@ func (f *SQLFormatter) writeKeyword(keyword string) {
}
}

// formatIdentifier formats an identifier, quoting it if it contains special characters
// or is a reserved keyword
func (f *SQLFormatter) formatIdentifier(ident string) {
if f.needsQuoting(ident) {
f.builder.WriteString("\"")
// Escape any existing double quotes by doubling them
escaped := strings.ReplaceAll(ident, "\"", "\"\"")
f.builder.WriteString(escaped)
f.builder.WriteString("\"")
} else {
f.builder.WriteString(ident)
}
}

// needsQuoting returns true if the identifier needs to be quoted
func (f *SQLFormatter) needsQuoting(ident string) bool {
if len(ident) == 0 {
return true
}
// Check if it starts with a digit
if ident[0] >= '0' && ident[0] <= '9' {
return true
}
// Check for special characters (allow only letters, digits, and underscore)
for _, c := range ident {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
return true
}
}
// Check for common SQL reserved keywords that might be used as aliases
reserved := map[string]bool{
"SELECT": true, "FROM": true, "WHERE": true, "AND": true, "OR": true,
"ORDER": true, "BY": true, "GROUP": true, "HAVING": true, "JOIN": true,
"LEFT": true, "RIGHT": true, "INNER": true, "OUTER": true, "ON": true,
"AS": true, "TABLE": true, "INDEX": true, "CREATE": true, "DROP": true,
"INSERT": true, "UPDATE": true, "DELETE": true, "INTO": true, "VALUES": true,
"SET": true, "NULL": true, "NOT": true, "IN": true, "LIKE": true,
"BETWEEN": true, "EXISTS": true, "CASE": true, "WHEN": true, "THEN": true,
"ELSE": true, "END": true, "DISTINCT": true, "ALL": true, "UNION": true,
}
return reserved[strings.ToUpper(ident)]
}

func (f *SQLFormatter) writeNewline() {
if !f.compact {
f.builder.WriteString("\n")
Expand Down
9 changes: 9 additions & 0 deletions pkg/gosqlx/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,9 @@ func (cc *columnCollector) collectFromExpression(expr ast.Expression) {
for _, v := range e.Values {
cc.collectFromExpression(v)
}
case *ast.AliasedExpression:
// Unwrap the aliased expression and collect from inner expression
cc.collectFromExpression(e.Expr)
}
}

Expand Down Expand Up @@ -790,6 +793,9 @@ func (qcc *qualifiedColumnCollector) collectFromExpression(expr ast.Expression)
for _, v := range e.Values {
qcc.collectFromExpression(v)
}
case *ast.AliasedExpression:
// Unwrap the aliased expression and collect from inner expression
qcc.collectFromExpression(e.Expr)
}
}

Expand Down Expand Up @@ -953,6 +959,9 @@ func (fc *functionCollector) collectFromExpression(expr ast.Expression) {
for _, v := range e.Values {
fc.collectFromExpression(v)
}
case *ast.AliasedExpression:
// Unwrap the aliased expression and collect from inner expression
fc.collectFromExpression(e.Expr)
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/lsp/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestHandler_DocumentSymbol(t *testing.T) {
{
name: "CREATE TABLE returns SymbolStruct kind",
sql: "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100))",
expectedCount: 0, // Parser may not support CREATE TABLE yet, expect 0
expectedCount: 1, // Parser now supports CREATE TABLE with column constraints
checkSymbols: func(t *testing.T, symbols []DocumentSymbol) {
// Only check if symbols were returned
if len(symbols) > 0 {
Expand Down
Loading
Loading