From 04acbd2359d93298a83d9d428e15e11307c7b019 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Sun, 25 Aug 2024 16:46:18 -0400 Subject: [PATCH 1/3] i don't know what i am doing --- .github/workflows/ci.yml | 24 + .github/workflows/release.yml | 10 +- .golangci.yml | 29 +- components/ConnectionForm.go | 23 +- components/ConnectionSelection.go | 36 +- components/Home.go | 36 +- components/ResultsTable.go | 418 ++++++++--------- components/TabbedMenu.go | 49 +- components/Tree.go | 30 +- drivers/constants.go | 5 + drivers/driver.go | 16 +- drivers/mysql.go | 333 ++++++++----- drivers/postgres.go | 757 +++++++++++++++++++++--------- drivers/sqlite.go | 321 +++++++++---- helpers/logger/logger.go | 135 ++++++ main.go | 27 +- models/models.go | 48 +- 17 files changed, 1557 insertions(+), 740 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 drivers/constants.go create mode 100644 helpers/logger/logger.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..96242eb --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,24 @@ +--- +name: Continuous Integration + +on: + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + + - name: Golangci-lint + uses: golangci/golangci-lint-action@v6.0.1 + + - name: Test + run: go test -v ./... + + - name: Build + run: go build -v ./... diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 55fdc11..9720e1f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -22,15 +22,7 @@ jobs: uses: actions/setup-go@v5 - - name: Golangci-lint - uses: golangci/golangci-lint-action@v6.0.1 - - - name: Install build dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential libc6-dev - - - name: Run GoReleaser + name: Run GoReleaser uses: goreleaser/goreleaser-action@v6 with: # either 'goreleaser' (default) or 'goreleaser-pro' diff --git a/.golangci.yml b/.golangci.yml index 55c9e8c..19268fb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -4,7 +4,7 @@ # Author: @ccoVeille # License: MIT # Variant: 03-safe -# Version: v1.0.0 +# Version: v1.0.0 + gosec + sqlclosecheck + rowserrcheck # linters: # some linters are enabled by default @@ -49,6 +49,15 @@ linters: # Checks for duplicate words in the source code. - dupword + # Inspects source code for security problems. + - gosec + + # Checks that sql.Rows, sql.Stmt, sqlx.NamedStmt, pgx.Query are closed. + - sqlclosecheck + + # Checks whether Rows.Err of rows is checked successfully. + - rowserrcheck + linters-settings: gci: # define the section orders for imports sections: @@ -132,12 +141,30 @@ linters-settings: # warns when initialism, variable or package naming conventions are not followed. - name: var-naming + # if-then-else conditional with identical implementations in both branches is an error. + - name: identical-branches + + # warns when errors returned by a function are not explicitly handled on the caller side. + - name: unhandled-error + arguments: # here are the exceptions we don't want to be reported + - "fmt.Print.*" + - "fmt.Fprint.*" + - "bytes.Buffer.Write*" + - "strings.Builder.Write*" + dupword: # Keywords used to ignore detection. # Default: [] ignore: # - "blah" # this will accept "blah blah …" as a valid duplicate word + gosec: + # To specify a set of rules to explicitly exclude. + # Available rules: https://github.com/securego/gosec#available-rules + excludes: + - G306 # Poor file permissions used when writing to a new file + - G307 # Poor file permissions used when creating a file with os.Create + misspell: # Correct spellings using locale preferences for US or UK. # Setting locale to US will correct the British spelling of 'colour' to 'color'. diff --git a/components/ConnectionForm.go b/components/ConnectionForm.go index f6881f8..36d6c26 100644 --- a/components/ConnectionForm.go +++ b/components/ConnectionForm.go @@ -29,31 +29,39 @@ func NewConnectionForm(connectionPages *models.ConnectionPages) *ConnectionForm buttonsWrapper := tview.NewFlex().SetDirection(tview.FlexColumn) - saveButton := tview.NewButton("[darkred]F1 [black]Save") + saveButton := tview.NewButton("[yellow]F1 [dark]Save") saveButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + saveButton.SetBorder(true) + buttonsWrapper.AddItem(saveButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) - testButton := tview.NewButton("[darkred]F2 [black]Test") + testButton := tview.NewButton("[yellow]F2 [dark]Test") testButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + testButton.SetBorder(true) + buttonsWrapper.AddItem(testButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) - connectButton := tview.NewButton("[darkred]F3 [black]Connect") + connectButton := tview.NewButton("[yellow]F3 [dark]Connect") connectButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + connectButton.SetBorder(true) + buttonsWrapper.AddItem(connectButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) - cancelButton := tview.NewButton("[darkred]Esc [black]Cancel") + cancelButton := tview.NewButton("[yellow]Esc [dark]Cancel") cancelButton.SetStyle(tcell.StyleDefault.Background(tcell.Color(tview.Styles.PrimaryTextColor))) + cancelButton.SetBorder(true) + buttonsWrapper.AddItem(cancelButton, 0, 1, false) statusText := tview.NewTextView() - statusText.SetBorderPadding(0, 1, 0, 0) + statusText.SetBorderPadding(1, 1, 0, 0) wrapper.AddItem(addForm, 0, 1, true) - wrapper.AddItem(statusText, 3, 0, false) - wrapper.AddItem(buttonsWrapper, 1, 0, false) + wrapper.AddItem(statusText, 4, 0, false) + wrapper.AddItem(buttonsWrapper, 3, 0, false) form := &ConnectionForm{ Flex: wrapper, @@ -81,7 +89,6 @@ func (form *ConnectionForm) inputCapture(connectionPages *models.ConnectionPages connectionString := form.GetFormItem(1).(*tview.InputField).GetText() parsed, err := helpers.ParseConnectionString(connectionString) - if err != nil { form.StatusText.SetText(err.Error()).SetTextStyle(tcell.StyleDefault.Foreground(tcell.ColorRed)) return event diff --git a/components/ConnectionSelection.go b/components/ConnectionSelection.go index 59b4310..b90fb85 100644 --- a/components/ConnectionSelection.go +++ b/components/ConnectionSelection.go @@ -27,36 +27,44 @@ func NewConnectionSelection(connectionForm *ConnectionForm, connectionPages *mod buttonsWrapper := tview.NewFlex().SetDirection(tview.FlexRowCSS) - newButton := tview.NewButton("[darkred]N[black]ew") - newButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + newButton := tview.NewButton("[yellow]N[dark]ew") + newButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + newButton.SetBorder(true) + buttonsWrapper.AddItem(newButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) - connectButton := tview.NewButton("[darkred]C[black]onnect") - connectButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + connectButton := tview.NewButton("[yellow]C[dark]onnect") + connectButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + connectButton.SetBorder(true) + buttonsWrapper.AddItem(connectButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) - editButton := tview.NewButton("[darkred]E[black]dit") - editButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + editButton := tview.NewButton("[yellow]E[dark]dit") + editButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + editButton.SetBorder(true) + buttonsWrapper.AddItem(editButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) - deleteButton := tview.NewButton("[darkred]D[black]elete") - deleteButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + deleteButton := tview.NewButton("[yellow]D[dark]elete") + deleteButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + deleteButton.SetBorder(true) + buttonsWrapper.AddItem(deleteButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) - quitButton := tview.NewButton("[darkred]Q[black]uit") - quitButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) - buttonsWrapper.AddItem(quitButton, 0, 1, false) + quitButton := tview.NewButton("[yellow]Q[dark]uit") + quitButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + quitButton.SetBorder(true) statusText := tview.NewTextView() - statusText.SetBorderPadding(0, 1, 0, 0) + statusText.SetBorderPadding(1, 1, 0, 0) wrapper.AddItem(ConnectionListTable, 0, 1, true) - wrapper.AddItem(statusText, 3, 0, false) - wrapper.AddItem(buttonsWrapper, 1, 0, false) + wrapper.AddItem(statusText, 4, 0, false) + wrapper.AddItem(buttonsWrapper, 3, 0, false) cs := &ConnectionSelection{ Flex: wrapper, diff --git a/components/Home.go b/components/Home.go index c279a3f..4ceee8e 100644 --- a/components/Home.go +++ b/components/Home.go @@ -1,6 +1,8 @@ package components import ( + "fmt" + "github.com/gdamore/tcell/v2" "github.com/rivo/tview" @@ -19,7 +21,6 @@ type Home struct { DBDriver drivers.Driver FocusedWrapper string ListOfDbChanges []models.DbDmlChange - ListOfDbInserts []models.DbInsert } func NewHomePage(connection models.Connection, dbdriver drivers.Driver) *Home { @@ -35,7 +36,6 @@ func NewHomePage(connection models.Connection, dbdriver drivers.Driver) *Home { LeftWrapper: leftWrapper, RightWrapper: rightWrapper, ListOfDbChanges: []models.DbDmlChange{}, - ListOfDbInserts: []models.DbInsert{}, DBDriver: dbdriver, } @@ -74,19 +74,25 @@ func (home *Home) subscribeToTreeChanges() { for stateChange := range ch { switch stateChange.Key { case "SelectedTable": + databaseName := home.Tree.GetSelectedDatabase() tableName := stateChange.Value.(string) - tab := home.TabbedPane.GetTabByName(tableName) + tabReference := fmt.Sprintf("%s.%s", databaseName, tableName) + + tab := home.TabbedPane.GetTabByReference(tabReference) + var table *ResultsTable if tab != nil { table = tab.Content - home.TabbedPane.SwitchToTabByName(tab.Name) + home.TabbedPane.SwitchToTabByReference(tab.Reference) } else { - table = NewResultsTable(&home.ListOfDbChanges, &home.ListOfDbInserts, home.Tree, home.DBDriver).WithFilter() - table.SetDBReference(tableName) + table = NewResultsTable(&home.ListOfDbChanges, home.Tree, home.DBDriver).WithFilter() + table.SetDatabaseName(databaseName) + table.SetTableName(tableName) + + home.TabbedPane.AppendTab(tableName, table, tabReference) - home.TabbedPane.AppendTab(tableName, table) } table.FetchRecords(func() { @@ -191,7 +197,7 @@ func (home *Home) rightWrapperInputCapture(event *tcell.EventKey) *tcell.EventKe if !table.GetIsFiltering() && !table.GetIsEditing() && !table.GetIsLoading() { home.TabbedPane.RemoveCurrentTab() - if home.TabbedPane.GetLenght() == 0 { + if home.TabbedPane.GetLength() == 0 { home.focusLeftWrapper() return nil } @@ -247,13 +253,14 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { home.focusRightWrapper() } } else if command == commands.SwitchToEditorView { - tab := home.TabbedPane.GetTabByName("Editor") + tab := home.TabbedPane.GetTabByReference("Editor") if tab != nil { - home.TabbedPane.SwitchToTabByName("Editor") + home.TabbedPane.SwitchToTabByReference("Editor") + tab.Content.SetIsFiltering(true) } else { - tableWithEditor := NewResultsTable(&home.ListOfDbChanges, &home.ListOfDbInserts, home.Tree, home.DBDriver).WithEditor() - home.TabbedPane.AppendTab("Editor", tableWithEditor) + tableWithEditor := NewResultsTable(&home.ListOfDbChanges, home.Tree, home.DBDriver).WithEditor() + home.TabbedPane.AppendTab("Editor", tableWithEditor, "Editor") tableWithEditor.SetIsFiltering(true) } home.focusRightWrapper() @@ -273,7 +280,7 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { App.Stop() } } else if command == commands.Save { - if (home.ListOfDbChanges != nil && len(home.ListOfDbChanges) > 0) || (home.ListOfDbInserts != nil && len(home.ListOfDbInserts) > 0) && !table.GetIsEditing() { + if (home.ListOfDbChanges != nil && len(home.ListOfDbChanges) > 0) && !table.GetIsEditing() { confirmationModal := NewConfirmationModal("") confirmationModal.SetDoneFunc(func(_ int, buttonLabel string) { @@ -282,13 +289,12 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { if buttonLabel == "Yes" { - err := home.DBDriver.ExecutePendingChanges(home.ListOfDbChanges, home.ListOfDbInserts) + err := home.DBDriver.ExecutePendingChanges(home.ListOfDbChanges) if err != nil { table.SetError(err.Error(), nil) } else { home.ListOfDbChanges = []models.DbDmlChange{} - home.ListOfDbInserts = []models.DbInsert{} table.FetchRecords(nil) home.Tree.ForceRemoveHighlight() diff --git a/components/ResultsTable.go b/components/ResultsTable.go index 780a160..29893fb 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -12,15 +12,16 @@ import ( "github.com/jorgerojas26/lazysql/app" "github.com/jorgerojas26/lazysql/commands" "github.com/jorgerojas26/lazysql/drivers" + "github.com/jorgerojas26/lazysql/helpers/logger" "github.com/jorgerojas26/lazysql/models" ) type ResultsTableState struct { listOfDbChanges *[]models.DbDmlChange - listOfDbInserts *[]models.DbInsert error string currentSort string - dbReference string + databaseName string + tableName string records [][]string columns [][]string constraints [][]string @@ -55,7 +56,7 @@ var ( DeleteColor = tcell.ColorRed ) -func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, listOfDbInserts *[]models.DbInsert, tree *Tree, dbdriver drivers.Driver) *ResultsTable { +func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, tree *Tree, dbdriver drivers.Driver) *ResultsTable { state := &ResultsTableState{ records: [][]string{}, columns: [][]string{}, @@ -65,7 +66,6 @@ func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, listOfDbInserts *[]m isEditing: false, isLoading: false, listOfDbChanges: listOfDbChanges, - listOfDbInserts: listOfDbInserts, } wrapper := tview.NewFlex() @@ -110,6 +110,8 @@ func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, listOfDbInserts *[]m table.SetInputCapture(table.tableInputCapture) table.SetSelectedStyle(tcell.StyleDefault.Background(tview.Styles.SecondaryTextColor).Foreground(tview.Styles.ContrastSecondaryTextColor)) + go table.subscribeToTreeChanges() + return table } @@ -173,6 +175,16 @@ func (table *ResultsTable) WithEditor() *ResultsTable { return table } +func (table *ResultsTable) subscribeToTreeChanges() { + ch := table.Tree.Subscribe() + + for stateChange := range ch { + if stateChange.Key == "SelectedDatabase" { + table.SetDatabaseName(stateChange.Value.(string)) + } + } +} + func (table *ResultsTable) AddRows(rows [][]string) { for i, row := range rows { for j, cell := range row { @@ -180,11 +192,7 @@ func (table *ResultsTable) AddRows(rows [][]string) { tableCell.SetSelectable(i > 0) tableCell.SetExpansion(1) - if i == 0 { - tableCell.SetTextColor(tview.Styles.PrimaryTextColor) - } else { - tableCell.SetTextColor(tview.Styles.PrimaryTextColor) - } + tableCell.SetTextColor(tview.Styles.PrimaryTextColor) table.SetCell(i, j, tableCell) } @@ -192,12 +200,19 @@ func (table *ResultsTable) AddRows(rows [][]string) { } func (table *ResultsTable) AddInsertedRows() { - inserts := *table.state.listOfDbInserts - rows := make([][]string, len(inserts)) + inserts := make([]models.DbDmlChange, 0) + + for _, change := range *table.state.listOfDbChanges { + if change.Type == models.DmlInsertType { + inserts = append(inserts, change) + } + } + + rows := make([][]models.CellValue, len(inserts)) if len(inserts) > 0 { for i, insert := range inserts { - if insert.Table == table.GetDBReference() && insert.Option == table.Menu.GetSelectedOption() { + if insert.Table == table.GetTableName() { rows[i] = insert.Values } } @@ -208,7 +223,7 @@ func (table *ResultsTable) AddInsertedRows() { rowIndex := rowCount + i for j, cell := range row { - tableCell := tview.NewTableCell(cell) + tableCell := tview.NewTableCell(cell.Value.(string)) tableCell.SetExpansion(1) tableCell.SetReference(inserts[i].PrimaryKeyValue) @@ -220,18 +235,27 @@ func (table *ResultsTable) AddInsertedRows() { } } -func (table *ResultsTable) InsertRow(cols []string, index int, UUID uuid.UUID) { - for i, cell := range cols { - tableCell := tview.NewTableCell(cell) +func (table *ResultsTable) AppendNewRow(cells []models.CellValue, index int, UUID string) { + for i, cell := range cells { + tableCell := tview.NewTableCell(cell.Value.(string)) tableCell.SetExpansion(1) - - if i == 0 { - tableCell.SetReference(UUID) - } + tableCell.SetReference(UUID) tableCell.SetTextColor(tview.Styles.PrimaryTextColor) + tableCell.SetBackgroundColor(tcell.ColorDarkGreen) + + switch cell.Type { + case models.Null: + case models.Default: + case models.String: + tableCell.SetText("") + tableCell.SetTextColor(tview.Styles.InverseTextColor) + } table.SetCell(index, i, tableCell) } + + table.Select(index, 0) + App.ForceDraw() } func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.EventKey { @@ -249,6 +273,7 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event if eventKey == '1' { table.Menu.SetSelectedOption(1) table.UpdateRows(table.GetRecords()) + table.AddInsertedRows() } else if eventKey == '2' { table.Menu.SetSelectedOption(2) table.UpdateRows(table.GetColumns()) @@ -266,51 +291,33 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event command := app.Keymaps.Group("table").Resolve(event) - if command == commands.AppendNewRow { - if table.Menu.GetSelectedOption() == 1 { - - newRow := make([]string, table.GetColumnCount()) - newRowIndex := table.GetRowCount() - newRowUUID := uuid.New() + if command == commands.AppendNewRow && (table.Menu != nil && table.Menu.GetSelectedOption() == 1) { + dbColumns := table.GetColumns() + newRowTableIndex := table.GetRowCount() + newRowUUID := uuid.New().String() + newRow := make([]models.CellValue, len(dbColumns)-1) - for i := 0; i < table.GetColumnCount(); i++ { - newRow[i] = "Default" + for i, column := range dbColumns { + if i != 0 { // Skip the first row because they are the column names (e.x "Field", "Type", "Null", "Key", "Default", "Extra") + newRow[i-1] = models.CellValue{Type: models.Default, Column: column[0], Value: "DEFAULT"} } + } - table.InsertRow(newRow, newRowIndex, newRowUUID) - - for i := 0; i < table.GetColumnCount(); i++ { - table.GetCell(newRowIndex, i).SetBackgroundColor(tcell.ColorDarkGreen) - } - - newInsert := models.DbInsert{ - Table: table.GetDBReference(), - Columns: table.GetRecords()[0], - Values: newRow, - PrimaryKeyValue: newRowUUID, - Option: 1, - } - - *table.state.listOfDbInserts = append(*table.state.listOfDbInserts, newInsert) - - if table.Tree.GetCurrentNode().GetColor() == tview.Styles.InverseTextColor || table.Tree.GetCurrentNode().GetColor() == tview.Styles.PrimaryTextColor { - table.Tree.GetCurrentNode().SetColor(InsertColor) - } else if table.Tree.GetCurrentNode().GetColor() == DeleteColor { - table.Tree.GetCurrentNode().SetColor(ChangeColor) - } + newInsert := models.DbDmlChange{ + Type: models.DmlInsertType, + Database: table.GetDatabaseName(), + Table: table.GetTableName(), + Values: newRow, + PrimaryKeyColumnName: "", + PrimaryKeyValue: newRowUUID, + } - table.Select(newRowIndex, 0) + *table.state.listOfDbChanges = append(*table.state.listOfDbChanges, newInsert) - App.ForceDraw() - table.StartEditingCell(newRowIndex, 0, func(newValue string, row, col int) { - cellReference := table.GetCell(row, 0).GetReference() + table.AppendNewRow(newRow, newRowTableIndex, newRowUUID) - if cellReference != nil { - table.MutateInsertedRowCell(cellReference.(uuid.UUID), col, newValue) - } - }) + table.StartEditingCell(newRowTableIndex, 0, nil) - } } else if command == commands.Search { if table.Editor != nil { App.SetFocus(table.Editor) @@ -416,13 +423,7 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event } if command == commands.Edit { - table.StartEditingCell(selectedRowIndex, selectedColumnIndex, func(newValue string, row, col int) { - cellReference := table.GetCell(row, 0).GetReference() - - if cellReference != nil { - table.MutateInsertedRowCell(cellReference.(uuid.UUID), col, newValue) - } - }) + table.StartEditingCell(selectedRowIndex, selectedColumnIndex, nil) } else if command == commands.GotoNext { if selectedColumnIndex+1 < colCount { table.Select(selectedRowIndex, selectedColumnIndex+1) @@ -456,17 +457,17 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event isAnInsertedRow := false indexOfInsertedRow := -1 - for i, insertedRow := range *table.state.listOfDbInserts { + for i, insertedRow := range *table.state.listOfDbChanges { cellReference := table.GetCell(selectedRowIndex, 0).GetReference() - if cellReference != nil && insertedRow.PrimaryKeyValue.String() == cellReference.(uuid.UUID).String() { + if cellReference != nil && insertedRow.PrimaryKeyValue == cellReference.(string) { isAnInsertedRow = true indexOfInsertedRow = i } } if isAnInsertedRow { - *table.state.listOfDbInserts = append((*table.state.listOfDbInserts)[:indexOfInsertedRow], (*table.state.listOfDbInserts)[indexOfInsertedRow+1:]...) + *table.state.listOfDbChanges = append((*table.state.listOfDbChanges)[:indexOfInsertedRow], (*table.state.listOfDbChanges)[indexOfInsertedRow+1:]...) table.RemoveRow(selectedRowIndex) if selectedRowIndex-1 != 0 { table.Select(selectedRowIndex-1, 0) @@ -475,16 +476,8 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event table.Select(selectedRowIndex+1, 0) } } - - // if len(*table.state.listOfDbChanges) == 0 && len(*table.state.listOfDbInserts) == 0 { - // table.Tree.ForceRemoveHighlight() - // } else if len(*table.state.listOfDbChanges) == 0 && len(*table.state.listOfDbInserts) > 0 { - // table.Tree.GetCurrentNode().SetColor(InsertColor) - // } else if len(*table.state.listOfDbChanges) > 0 && len(*table.state.listOfDbInserts) == 0 { - // table.Tree.GetCurrentNode().SetColor(ChangeColor) - // } } else { - table.AppendNewChange("DELETE", table.GetDBReference(), selectedRowIndex, -1, "") + table.AppendNewChange(models.DmlDeleteType, table.GetDatabaseName(), table.GetTableName(), selectedRowIndex, -1, models.CellValue{}) } } @@ -523,7 +516,6 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event func (table *ResultsTable) UpdateRows(rows [][]string) { table.Clear() table.AddRows(rows) - table.AddInsertedRows() App.ForceDraw() table.Select(1, 0) } @@ -710,8 +702,16 @@ func (table *ResultsTable) GetForeignKeys() [][]string { return table.state.foreignKeys } -func (table *ResultsTable) GetDBReference() string { - return table.state.dbReference +func (table *ResultsTable) GetTableName() string { + return table.state.tableName +} + +func (table *ResultsTable) GetDatabaseName() string { + return table.state.databaseName +} + +func (table *ResultsTable) GetDatabaseAndTableName() string { + return fmt.Sprintf("%s.%s", table.GetDatabaseName(), table.GetTableName()) } func (table *ResultsTable) GetIsEditing() bool { @@ -765,8 +765,12 @@ func (table *ResultsTable) SetIndexes(indexes [][]string) { table.state.indexes = indexes } -func (table *ResultsTable) SetDBReference(dbReference string) { - table.state.dbReference = dbReference +func (table *ResultsTable) SetDatabaseName(databaseName string) { + table.state.databaseName = databaseName +} + +func (table *ResultsTable) SetTableName(tableName string) { + table.state.tableName = tableName } func (table *ResultsTable) SetError(err string, done func()) { @@ -799,6 +803,18 @@ func (table *ResultsTable) SetResultsInfo(text string) { } func (table *ResultsTable) SetLoading(show bool) { + defer func() { + if r := recover(); r != nil { + logger.Error("ResultsTable.go:800 => Recovered from panic", map[string]any{"error": r}) + _ = table.Page.HidePage("loading") + if table.state.error != "" { + App.SetFocus(table.Error) + } else { + App.SetFocus(table) + } + } + }() + table.state.isLoading = show if show { table.Page.ShowPage("loading") @@ -836,7 +852,7 @@ func (table *ResultsTable) SetSortedBy(column string, direction string) { where = table.Filter.GetCurrentFilter() } table.SetLoading(true) - records, _, err := table.DBDriver.GetRecords(table.GetDBReference(), where, sort, table.Pagination.GetOffset(), table.Pagination.GetLimit()) + records, _, err := table.DBDriver.GetRecords(table.GetDatabaseName(), table.GetTableName(), where, sort, table.Pagination.GetOffset(), table.Pagination.GetLimit()) table.SetLoading(false) if err != nil { @@ -874,7 +890,8 @@ func (table *ResultsTable) SetSortedBy(column string, direction string) { } func (table *ResultsTable) FetchRecords(onError func()) [][]string { - tableName := table.GetDBReference() + tableName := table.GetTableName() + databaseName := table.GetDatabaseName() table.SetLoading(true) @@ -884,7 +901,7 @@ func (table *ResultsTable) FetchRecords(onError func()) [][]string { } sort := table.GetCurrentSort() - records, totalRecords, err := table.DBDriver.GetRecords(tableName, where, sort, table.Pagination.GetOffset(), table.Pagination.GetLimit()) + records, totalRecords, err := table.DBDriver.GetRecords(databaseName, tableName, where, sort, table.Pagination.GetOffset(), table.Pagination.GetLimit()) if err != nil { table.SetError(err.Error(), onError) @@ -894,10 +911,10 @@ func (table *ResultsTable) FetchRecords(onError func()) [][]string { table.SetIsFiltering(false) } - columns, _ := table.DBDriver.GetTableColumns(table.Tree.GetSelectedDatabase(), tableName) - constraints, _ := table.DBDriver.GetConstraints(tableName) - foreignKeys, _ := table.DBDriver.GetForeignKeys(tableName) - indexes, _ := table.DBDriver.GetIndexes(tableName) + columns, _ := table.DBDriver.GetTableColumns(databaseName, tableName) + constraints, _ := table.DBDriver.GetConstraints(databaseName, tableName) + foreignKeys, _ := table.DBDriver.GetForeignKeys(databaseName, tableName) + indexes, _ := table.DBDriver.GetIndexes(databaseName, tableName) if len(records) > 0 { table.SetRecords(records) @@ -907,7 +924,6 @@ func (table *ResultsTable) FetchRecords(onError func()) [][]string { table.SetConstraints(constraints) table.SetForeignKeys(foreignKeys) table.SetIndexes(indexes) - table.SetDBReference(tableName) table.Select(1, 0) table.Pagination.SetTotalRecords(totalRecords) @@ -934,34 +950,35 @@ func (table *ResultsTable) StartEditingCell(row int, col int, callback func(newV table.SetIsEditing(false) currentValue := cell.Text newValue := inputField.GetText() - if key == tcell.KeyEnter { - if currentValue != newValue { + columnName := table.GetCell(0, col).Text - cell.SetText(inputField.GetText()) - - table.AppendNewChange("UPDATE", table.GetDBReference(), row, col, newValue) + if key != tcell.KeyEscape { + cell.SetText(newValue) + if currentValue != newValue { + table.AppendNewChange(models.DmlUpdateType, table.GetDatabaseName(), table.GetTableName(), row, col, models.CellValue{Type: models.String, Value: newValue, Column: columnName}) } - } else if key == tcell.KeyTab { - nextEditableColumnIndex := col + 1 - if nextEditableColumnIndex <= table.GetColumnCount()-1 { - cell.SetText(inputField.GetText()) - table.Select(row, nextEditableColumnIndex) + if key == tcell.KeyTab { + nextEditableColumnIndex := col + 1 - table.StartEditingCell(row, nextEditableColumnIndex, callback) + if nextEditableColumnIndex <= table.GetColumnCount()-1 { + table.Select(row, nextEditableColumnIndex) - } - } else if key == tcell.KeyBacktab { - nextEditableColumnIndex := col - 1 + table.StartEditingCell(row, nextEditableColumnIndex, callback) + + } + } else if key == tcell.KeyBacktab { + nextEditableColumnIndex := col - 1 - if nextEditableColumnIndex >= 0 { - cell.SetText(inputField.GetText()) - table.Select(row, nextEditableColumnIndex) + if nextEditableColumnIndex >= 0 { + table.Select(row, nextEditableColumnIndex) - table.StartEditingCell(row, nextEditableColumnIndex, callback) + table.StartEditingCell(row, nextEditableColumnIndex, callback) + } } + } if key == tcell.KeyEnter || key == tcell.KeyEscape { @@ -981,9 +998,9 @@ func (table *ResultsTable) StartEditingCell(row int, col int, callback func(newV App.SetFocus(inputField) } -func (table *ResultsTable) CheckIfRowIsInserted(rowID uuid.UUID) bool { - for _, insertedRow := range *table.state.listOfDbInserts { - if insertedRow.PrimaryKeyValue == rowID { +func (table *ResultsTable) CheckIfRowIsInserted(rowID string) bool { + for _, dmlChange := range *table.state.listOfDbChanges { + if dmlChange.Type == models.DmlInsertType && dmlChange.PrimaryKeyValue == rowID { return true } } @@ -991,129 +1008,98 @@ func (table *ResultsTable) CheckIfRowIsInserted(rowID uuid.UUID) bool { return false } -func (table *ResultsTable) MutateInsertedRowCell(rowID uuid.UUID, colIndex int, newValue string) { - for i, insertedRow := range *table.state.listOfDbInserts { - if insertedRow.PrimaryKeyValue == rowID { - (*table.state.listOfDbInserts)[i].Values[colIndex] = newValue +func (table *ResultsTable) MutateInsertedRowCell(rowID string, newValue models.CellValue) { + for i, dmlChange := range *table.state.listOfDbChanges { + if dmlChange.PrimaryKeyValue == rowID && dmlChange.Type == models.DmlInsertType { + for j, v := range dmlChange.Values { + if v.Column == newValue.Column { + (*table.state.listOfDbChanges)[i].Values[j] = newValue + break + } + } } } } -// TODO: encapsulate logic for different changeType -func (table *ResultsTable) AppendNewChange(changeType string, tableName string, rowIndex int, colIndex int, value string) { - // check if there is already a change row in the listOfDbChanges variable - // if there is, update the value - // if there isn't, append a new change row - // if the value is the same as the original value, remove the change row +func (table *ResultsTable) AppendNewChange(changeType models.DmlType, databaseName, tableName string, rowIndex int, colIndex int, value models.CellValue) { + dmlChangeAlreadyExists := false - cellReference := table.GetCell(rowIndex, 0).GetReference() + // If the column has a reference, it means it's an inserted rowIndex + // These is maybe a better way to detect it is an inserted row + tableCell := table.GetCell(rowIndex, colIndex) + tableCellReference := tableCell.GetReference() - isInsertedRow := false + isAnInsertedRow := tableCellReference != nil - if cellReference != nil { - isInsertedRow = table.CheckIfRowIsInserted(cellReference.(uuid.UUID)) + if isAnInsertedRow { + table.MutateInsertedRowCell(tableCellReference.(string), value) + return } - if !isInsertedRow { - primaryKeyValue, primaryKeyColumnName := table.GetPrimaryKeyValue(rowIndex) + primaryKeyValue, primaryKeyColumnName := table.GetPrimaryKeyValue(rowIndex) + + for i, dmlChange := range *table.state.listOfDbChanges { + if dmlChange.Table == tableName && dmlChange.Type == changeType && dmlChange.PrimaryKeyValue == primaryKeyValue { + dmlChangeAlreadyExists = true - alreadyExists := false - indexOfChange := -1 + changeForColExists := false + valueIndex := -1 - for i, change := range *table.state.listOfDbChanges { - if change.PrimaryKeyValue == primaryKeyValue && change.Column == table.GetColumnNameByIndex(colIndex) { - alreadyExists = true - indexOfChange = i + for j, v := range dmlChange.Values { + if v.Column == value.Column { + changeForColExists = true + valueIndex = j + break + } } - } - switch changeType { - case "UPDATE": - cell := table.GetCell(rowIndex, colIndex) - columnName := table.GetColumnNameByIndex(colIndex) - originalCellValue := table.GetRecords()[rowIndex][colIndex] - - if alreadyExists { - if value == originalCellValue { - *table.state.listOfDbChanges = append((*table.state.listOfDbChanges)[:indexOfChange], (*table.state.listOfDbChanges)[indexOfChange+1:]...) - - cell.SetBackgroundColor(tcell.ColorDefault) - cell.SetTextColor(tview.Styles.PrimaryTextColor) - - // if len(*table.state.listOfDbChanges) == 0 && len(*table.state.listOfDbInserts) == 0 { - // table.Tree.GetCurrentNode().SetColor(tview.Styles.InverseTextColor) - // } else if len(*table.state.listOfDbChanges) == 0 && len(*table.state.listOfDbInserts) > 0 { - // table.Tree.GetCurrentNode().SetColor(InsertColor) - // } else if len(*table.state.listOfDbChanges) > 0 && len(*table.state.listOfDbInserts) == 0 { - // table.Tree.GetCurrentNode().SetColor(ChangeColor) - // } + switch changeType { + case models.DmlUpdateType: + originalValue := table.GetRecords()[rowIndex][colIndex] + if changeForColExists { + if originalValue == value.Value { + if len((*table.state.listOfDbChanges)[i].Values) == 1 { + *table.state.listOfDbChanges = append((*table.state.listOfDbChanges)[:i], (*table.state.listOfDbChanges)[i+1:]...) + } else { + (*table.state.listOfDbChanges)[i].Values = append((*table.state.listOfDbChanges)[i].Values[:valueIndex], (*table.state.listOfDbChanges)[i].Values[valueIndex+1:]...) + } + table.SetCellColor(rowIndex, colIndex, tview.Styles.PrimitiveBackgroundColor) + } else { + (*table.state.listOfDbChanges)[i].Values[valueIndex] = value + } } else { - cell.SetBackgroundColor(tcell.ColorOrange.TrueColor()) - cell.SetTextColor(tcell.ColorBlack.TrueColor()) - // table.Tree.GetCurrentNode().SetColor(ChangeColor) - - (*table.state.listOfDbChanges)[indexOfChange].Value = value - } - } else { - newChange := models.DbDmlChange{ - Type: changeType, - Table: tableName, - Column: columnName, - Value: value, - PrimaryKeyColumnName: primaryKeyColumnName, - PrimaryKeyValue: primaryKeyValue, - Option: 1, + (*table.state.listOfDbChanges)[i].Values = append((*table.state.listOfDbChanges)[i].Values, value) + table.SetCellColor(rowIndex, colIndex, ChangeColor) } - *table.state.listOfDbChanges = append(*table.state.listOfDbChanges, newChange) - - cell.SetBackgroundColor(tcell.ColorOrange.TrueColor()) - cell.SetTextColor(tcell.ColorBlack.TrueColor()) - // table.Tree.GetCurrentNode().SetColor(ChangeColor) + case models.DmlDeleteType: + *table.state.listOfDbChanges = append((*table.state.listOfDbChanges)[:i], (*table.state.listOfDbChanges)[i+1:]...) + table.SetRowColor(rowIndex, tview.Styles.PrimitiveBackgroundColor) } - case "DELETE": - if alreadyExists { - - *table.state.listOfDbChanges = append((*table.state.listOfDbChanges)[:indexOfChange], (*table.state.listOfDbChanges)[indexOfChange+1:]...) - - // if len(*table.state.listOfDbChanges) == 0 && len(*table.state.listOfDbInserts) == 0 { - // table.Tree.GetCurrentNode().SetColor(tview.Styles.InverseTextColor) - // } else if len(*table.state.listOfDbChanges) == 0 && len(*table.state.listOfDbInserts) > 0 { - // table.Tree.GetCurrentNode().SetColor(InsertColor) - // } else if len(*table.state.listOfDbChanges) > 0 && len(*table.state.listOfDbInserts) == 0 { - // table.Tree.GetCurrentNode().SetColor(ChangeColor) - // } + } + } - for i := 0; i < table.GetColumnCount(); i++ { - table.GetCell(rowIndex, i).SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor) - } + if !dmlChangeAlreadyExists { - } else { + switch changeType { + case models.DmlDeleteType: + table.SetRowColor(rowIndex, DeleteColor) + case models.DmlUpdateType: + table.SetCellColor(rowIndex, colIndex, ChangeColor) + } - // if table.Tree.GetCurrentNode().GetColor() == tview.Styles.InverseTextColor || table.Tree.GetCurrentNode().GetColor() == tview.Styles.PrimaryTextColor { - // table.Tree.GetCurrentNode().SetColor(DeleteColor) - // } else if table.Tree.GetCurrentNode().GetColor() == InsertColor { - // table.Tree.GetCurrentNode().SetColor(ChangeColor) - // } - - newChange := models.DbDmlChange{ - Type: changeType, - Table: tableName, - Column: "", - Value: "", - PrimaryKeyColumnName: primaryKeyColumnName, - PrimaryKeyValue: primaryKeyValue, - Option: 1, - } + newDmlChange := models.DbDmlChange{ + Type: changeType, + Database: databaseName, + Table: tableName, + Values: []models.CellValue{value}, + PrimaryKeyColumnName: primaryKeyColumnName, + PrimaryKeyValue: primaryKeyValue, + } - *table.state.listOfDbChanges = append(*table.state.listOfDbChanges, newChange) + *table.state.listOfDbChanges = append(*table.state.listOfDbChanges, newDmlChange) - for i := 0; i < table.GetColumnCount(); i++ { - table.GetCell(rowIndex, i).SetBackgroundColor(DeleteColor) - } - } - } } } @@ -1213,3 +1199,13 @@ func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { return primaryKeyValue, primaryKeyColumnName } + +func (table *ResultsTable) SetRowColor(rowIndex int, color tcell.Color) { + for i := 0; i < table.GetColumnCount(); i++ { + table.GetCell(rowIndex, i).SetBackgroundColor(color) + } +} + +func (table *ResultsTable) SetCellColor(rowIndex int, colIndex int, color tcell.Color) { + table.GetCell(rowIndex, colIndex).SetBackgroundColor(color) +} diff --git a/components/TabbedMenu.go b/components/TabbedMenu.go index 73ac679..75dfd0f 100644 --- a/components/TabbedMenu.go +++ b/components/TabbedMenu.go @@ -16,6 +16,7 @@ type Tab struct { PreviousTab *Tab Header *Header Name string + Reference string } type TabbedPaneState struct { @@ -42,15 +43,16 @@ func NewTabbedPane() *TabbedPane { } } -func (t *TabbedPane) AppendTab(name string, content *ResultsTable) { +func (t *TabbedPane) AppendTab(name string, content *ResultsTable, reference string) { textView := tview.NewTextView() textView.SetText(name) item := &Header{textView} newTab := &Tab{ - Content: content, - Name: name, - Header: item, + Content: content, + Name: name, + Header: item, + Reference: reference, } t.state.Length++ @@ -70,7 +72,7 @@ func (t *TabbedPane) AppendTab(name string, content *ResultsTable) { t.HighlightTabHeader(newTab) - t.AddAndSwitchToPage(name, content.Page, true) + t.AddAndSwitchToPage(reference, content.Page, true) } func (t *TabbedPane) RemoveCurrentTab() { @@ -78,7 +80,7 @@ func (t *TabbedPane) RemoveCurrentTab() { if currentTab != nil { t.HeaderContainer.RemoveItem(currentTab.Header) - t.RemovePage(currentTab.Name) + t.RemovePage(currentTab.Reference) t.state.Length-- @@ -114,7 +116,7 @@ func (t *TabbedPane) SetCurrentTab(tab *Tab) *Tab { t.state.CurrentTab = tab t.HighlightTabHeader(tab) - t.SwitchToPage(tab.Name) + t.SwitchToPage(tab.Reference) app.App.SetFocus(tab.Content.Page) @@ -137,7 +139,20 @@ func (t *TabbedPane) GetTabByName(name string) *Tab { return tab } -func (t *TabbedPane) GetLenght() int { +func (t *TabbedPane) GetTabByReference(reference string) *Tab { + tab := t.state.FirstTab + + for i := 0; tab != nil && i < t.state.Length; i++ { + if tab.Reference == reference { + break + } + tab = tab.NextTab + } + + return tab +} + +func (t *TabbedPane) GetLength() int { return t.state.Length } @@ -203,6 +218,24 @@ func (t *TabbedPane) SwitchToTabByName(name string) *Tab { return nil } +func (t *TabbedPane) SwitchToTabByReference(reference string) *Tab { + tab := t.state.FirstTab + + for i := 0; tab != nil && i < t.state.Length; i++ { + if tab.Reference == reference { + break + } + tab = tab.NextTab + } + + if tab != nil { + t.SetCurrentTab(tab) + return tab + } + + return nil +} + func (t *TabbedPane) HighlightTabHeader(tab *Tab) { tabToHighlight := t.state.FirstTab diff --git a/components/Tree.go b/components/Tree.go index d8b294c..36e1010 100644 --- a/components/Tree.go +++ b/components/Tree.go @@ -112,16 +112,32 @@ func NewTree(dbName string, dbdriver drivers.Driver) *Tree { } } else if node.GetLevel() == 2 { if node.GetChildren() == nil { - tableName := node.GetReference().(string) + nodeReference := node.GetReference().(string) + split := strings.Split(nodeReference, ".") + databaseName := "" + tableName := "" + + if len(split) == 1 { + tableName = split[0] + } else if len(split) > 1 { + databaseName = split[0] + tableName = split[1] + } + tree.SetSelectedDatabase(databaseName) tree.SetSelectedTable(tableName) } else { node.SetExpanded(!node.IsExpanded()) } } else if node.GetLevel() == 3 { - tableName := node.GetReference().(string) - - tree.SetSelectedTable(tableName) + nodeReference := node.GetReference().(string) + split := strings.Split(nodeReference, ".") + databaseName := split[0] + schemaName := split[1] + tableName := split[2] + + tree.SetSelectedDatabase(databaseName) + tree.SetSelectedTable(fmt.Sprintf("%s.%s", schemaName, tableName)) } }) @@ -163,7 +179,9 @@ func (tree *Tree) updateNodes(children map[string][]string, node *tview.TreeNode for key, values := range children { var rootNode *tview.TreeNode - if key != node.GetReference().(string) { + nodeReference := node.GetReference().(string) + + if key != nodeReference { rootNode = tview.NewTreeNode(key) rootNode.SetExpanded(false) rootNode.SetReference(key) @@ -177,6 +195,8 @@ func (tree *Tree) updateNodes(children map[string][]string, node *tview.TreeNode childNode.SetColor(tview.Styles.PrimaryTextColor) if tree.DBDriver.GetProvider() == "sqlite3" { childNode.SetReference(child) + } else if tree.DBDriver.GetProvider() == "postgres" { + childNode.SetReference(fmt.Sprintf("%s.%s.%s", nodeReference, key, child)) } else { childNode.SetReference(fmt.Sprintf("%s.%s", key, child)) } diff --git a/drivers/constants.go b/drivers/constants.go new file mode 100644 index 0000000..5f46aaa --- /dev/null +++ b/drivers/constants.go @@ -0,0 +1,5 @@ +package drivers + +const ( + DefaultRowLimit = 300 +) diff --git a/drivers/driver.go b/drivers/driver.go index 7d1e8e3..4c54a88 100644 --- a/drivers/driver.go +++ b/drivers/driver.go @@ -10,15 +10,15 @@ type Driver interface { GetDatabases() ([]string, error) GetTables(database string) (map[string][]string, error) GetTableColumns(database, table string) ([][]string, error) - GetConstraints(table string) ([][]string, error) - GetForeignKeys(table string) ([][]string, error) - GetIndexes(table string) ([][]string, error) - GetRecords(table, where, sort string, offset, limit int) ([][]string, int, error) - UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) error - DeleteRecord(table string, primaryKeyColumnName, primaryKeyValue string) error + GetConstraints(database, table string) ([][]string, error) + GetForeignKeys(database, table string) ([][]string, error) + GetIndexes(database, table string) ([][]string, error) + GetRecords(database, table, where, sort string, offset, limit int) ([][]string, int, error) + UpdateRecord(database, table, column, value, primaryKeyColumnName, primaryKeyValue string) error + DeleteRecord(database, table string, primaryKeyColumnName, primaryKeyValue string) error ExecuteDMLStatement(query string) (string, error) ExecuteQuery(query string) ([][]string, error) - ExecutePendingChanges(changes []models.DbDmlChange, inserts []models.DbInsert) error - SetProvider(provider string) + ExecutePendingChanges(changes []models.DbDmlChange) error + SetProvider(provider string) // NOTE: This is used to get the primary key from the database table until i find a better way to do it. See ResultsTable.go GetPrimaryKeyValue function GetProvider() string } diff --git a/drivers/mysql.go b/drivers/mysql.go index d8611b7..05472b0 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -4,11 +4,11 @@ import ( "database/sql" "errors" "fmt" - "strconv" "strings" "github.com/xo/dburl" + "github.com/jorgerojas26/lazysql/helpers/logger" "github.com/jorgerojas26/lazysql/models" ) @@ -45,6 +45,13 @@ func (db *MySQL) GetDatabases() ([]string, error) { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + + defer rows.Close() + for rows.Next() { var database string err := rows.Scan(&database) @@ -60,8 +67,19 @@ func (db *MySQL) GetDatabases() ([]string, error) { } func (db *MySQL) GetTables(database string) (map[string][]string, error) { + if database == "" { + return nil, errors.New("database name is required") + } + rows, err := db.Connection.Query(fmt.Sprintf("SHOW TABLES FROM `%s`", database)) + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + + defer rows.Close() + tables := make(map[string][]string) if err != nil { @@ -81,14 +99,28 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) { return tables, nil } -// TODO: Rewrite this logic to use the database name instead of the table name, which for now has the format `database.table` -func (db *MySQL) GetTableColumns(_, table string) (results [][]string, err error) { - table = db.formatTableName(table) +func (db *MySQL) GetTableColumns(database, table string) (results [][]string, err error) { + if database == "" { + return nil, errors.New("database name is required") + } + + if table == "" { + return nil, errors.New("table name is required") + } + + query := "DESCRIBE " + query += db.formatTableName(database, table) - rows, err := db.Connection.Query(fmt.Sprintf("DESCRIBE %s", table)) + rows, err := db.Connection.Query(query) if err != nil { return nil, err } + + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -100,6 +132,7 @@ func (db *MySQL) GetTableColumns(_, table string) (results [][]string, err error for rows.Next() { rowValues := make([]interface{}, len(columns)) + for i := range columns { rowValues[i] = new(sql.RawBytes) } @@ -120,18 +153,27 @@ func (db *MySQL) GetTableColumns(_, table string) (results [][]string, err error return } -func (db *MySQL) GetConstraints(table string) (results [][]string, err error) { - table = db.formatTableName(table) +func (db *MySQL) GetConstraints(database, table string) (results [][]string, err error) { + if database == "" { + return nil, errors.New("database name is required") + } + + if table == "" { + return nil, errors.New("table name is required") + } - splitTableString := strings.Split(table, ".") - database := splitTableString[0] - tableName := splitTableString[1] + query := "SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME FROM information_schema.KEY_COLUMN_USAGE where TABLE_SCHEMA = ? AND TABLE_NAME = ?" - rows, err := db.Connection.Query(fmt.Sprintf("SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME FROM information_schema.KEY_COLUMN_USAGE where TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s'", database, tableName)) + rows, err := db.Connection.Query(query, database, table) if err != nil { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -163,17 +205,27 @@ func (db *MySQL) GetConstraints(table string) (results [][]string, err error) { return } -func (db *MySQL) GetForeignKeys(table string) (results [][]string, err error) { - table = db.formatTableName(table) - splitTableString := strings.Split(table, ".") - database := splitTableString[0] - tableName := splitTableString[1] +func (db *MySQL) GetForeignKeys(database, table string) (results [][]string, err error) { + if database == "" { + return nil, errors.New("database name is required") + } + + if table == "" { + return nil, errors.New("table name is required") + } + + query := "SELECT TABLE_NAME, COLUMN_NAME, CONSTRAINT_NAME, REFERENCED_COLUMN_NAME, REFERENCED_TABLE_NAME FROM information_schema.KEY_COLUMN_USAGE where REFERENCED_TABLE_SCHEMA = ? AND REFERENCED_TABLE_NAME = ?" - rows, err := db.Connection.Query(fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, CONSTRAINT_NAME, REFERENCED_COLUMN_NAME, REFERENCED_TABLE_NAME FROM information_schema.KEY_COLUMN_USAGE where REFERENCED_TABLE_SCHEMA = '%s' AND REFERENCED_TABLE_NAME = '%s'", database, tableName)) + rows, err := db.Connection.Query(query, database, table) if err != nil { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -205,12 +257,28 @@ func (db *MySQL) GetForeignKeys(table string) (results [][]string, err error) { return } -func (db *MySQL) GetIndexes(table string) (results [][]string, err error) { - table = db.formatTableName(table) - rows, err := db.Connection.Query("SHOW INDEX FROM " + table) +func (db *MySQL) GetIndexes(database, table string) (results [][]string, err error) { + if database == "" { + return nil, errors.New("database name is required") + } + + if table == "" { + return nil, errors.New("table name is required") + } + + query := "SHOW INDEX FROM " + query += db.formatTableName(database, table) + + rows, err := db.Connection.Query(query) if err != nil { return nil, err } + + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -242,46 +310,58 @@ func (db *MySQL) GetIndexes(table string) (results [][]string, err error) { return } -func (db *MySQL) GetRecords(table, where, sort string, offset, limit int) (paginatedResults [][]string, totalRecords int, err error) { - table = db.formatTableName(table) - defaultLimit := 300 +func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit int) (paginatedResults [][]string, totalRecords int, err error) { + if table == "" { + return nil, 0, errors.New("table name is required") + } - isPaginationEnabled := offset >= 0 && limit >= 0 + if database == "" { + return nil, 0, errors.New("database name is required") + } - if limit != 0 { - defaultLimit = limit + if limit == 0 { + limit = DefaultRowLimit } - query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d,%d", table, offset, defaultLimit) + query := "SELECT * FROM " + query += db.formatTableName(database, table) if where != "" { - query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d,%d", table, where, offset, defaultLimit) + query += fmt.Sprintf(" %s", where) } if sort != "" { - query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d,%d", table, where, sort, offset, defaultLimit) + query += fmt.Sprintf(" ORDER BY %s", sort) } - paginatedRows, err := db.Connection.Query(query) + query += " LIMIT ?, ?" + + paginatedRows, err := db.Connection.Query(query, offset, limit) if err != nil { return nil, 0, err } - if isPaginationEnabled { - queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", table, where) + rowsErr := paginatedRows.Err() - rows := db.Connection.QueryRow(queryWithoutLimit) + if rowsErr != nil { + return nil, 0, rowsErr + } - if err != nil { - return nil, 0, err - } + defer paginatedRows.Close() - err = rows.Scan(&totalRecords) - if err != nil { - return nil, 0, err - } + countQuery := "SELECT COUNT(*) FROM " + countQuery += fmt.Sprintf("`%s`.", database) + countQuery += fmt.Sprintf("`%s`", table) - defer paginatedRows.Close() + rows := db.Connection.QueryRow(countQuery) + + if err != nil { + return nil, 0, err + } + + err = rows.Scan(&totalRecords) + if err != nil { + return nil, 0, err } columns, err := paginatedRows.Columns() @@ -320,6 +400,11 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -352,20 +437,21 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) { return } -// TODO: Rewrites this logic to use the primary key instead of the id -func (db *MySQL) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) error { - table = db.formatTableName(table) - query := fmt.Sprintf("UPDATE %s SET %s = \"%s\" WHERE %s = \"%s\"", table, column, value, primaryKeyColumnName, primaryKeyValue) - _, err := db.Connection.Exec(query) +func (db *MySQL) UpdateRecord(database, table, column, value, primaryKeyColumnName, primaryKeyValue string) error { + query := "UPDATE " + query += db.formatTableName(database, table) + query += fmt.Sprintf(" SET %s = ? WHERE %s = ?", column, primaryKeyColumnName) + + _, err := db.Connection.Exec(query, value, primaryKeyValue) return err } -// TODO: Rewrites this logic to use the primary key instead of the id -func (db *MySQL) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) error { - table = db.formatTableName(table) - query := fmt.Sprintf("DELETE FROM %s WHERE %s = \"%s\"", table, primaryKeyColumnName, primaryKeyValue) - _, err := db.Connection.Exec(query) +func (db *MySQL) DeleteRecord(database, table, primaryKeyColumnName, primaryKeyValue string) error { + query := "DELETE FROM " + query += db.formatTableName(database, table) + query += fmt.Sprintf(" WHERE %s = ?", primaryKeyColumnName) + _, err := db.Connection.Exec(query, primaryKeyValue) return err } @@ -384,94 +470,104 @@ func (db *MySQL) ExecuteDMLStatement(query string) (result string, err error) { return fmt.Sprintf("%d rows affected", rowsAffected), nil } -func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []models.DbInsert) (err error) { - queries := make([]string, 0, len(changes)+len(inserts)) +func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { + var query []models.Query - // This will hold grouped changes by their RowId and Table - groupedUpdated := make(map[string][]models.DbDmlChange) - groupedDeletes := make([]models.DbDmlChange, 0, len(changes)) - - // Group changes by RowId and Table for _, change := range changes { - switch change.Type { - case "UPDATE": - key := fmt.Sprintf("%s|%s|%s", change.Table, change.PrimaryKeyColumnName, change.PrimaryKeyValue) - groupedUpdated[key] = append(groupedUpdated[key], change) - case "DELETE": - groupedDeletes = append(groupedDeletes, change) + columnNames := []string{} + values := []interface{}{} + valuesPlaceholder := []string{} + + for _, cell := range change.Values { + switch cell.Type { + case models.Empty, models.Null, models.String: + columnNames = append(columnNames, cell.Column) + valuesPlaceholder = append(valuesPlaceholder, "?") + } } - } - - // Combine individual changes to SQL statements - for key, changes := range groupedUpdated { - columns := []string{} - // Split key into table and rowId - splitted := strings.Split(key, "|") - table := db.formatTableName(splitted[0]) - primaryKeyColumnName := splitted[1] - primaryKeyValue := splitted[2] - - for _, change := range changes { - columns = append(columns, fmt.Sprintf("%s='%s'", change.Column, change.Value)) + for _, cell := range change.Values { + switch cell.Type { + case models.Empty: + values = append(values, "") + case models.Null: + values = append(values, sql.NullString{}) + case models.String: + values = append(values, cell.Value) + case models.Default: + break + } } - // Merge all column updates - updateClause := strings.Join(columns, ", ") - - query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = '%s';", table, updateClause, primaryKeyColumnName, primaryKeyValue) - - queries = append(queries, query) - } + switch change.Type { + case models.DmlInsertType: + queryStr := "INSERT INTO " + queryStr += db.formatTableName(change.Database, change.Table) + queryStr += fmt.Sprintf(" (%s) VALUES (%s)", strings.Join(columnNames, ", "), strings.Join(valuesPlaceholder, ", ")) + + newQuery := models.Query{ + Query: queryStr, + Args: values, + } - for _, delete := range groupedDeletes { - statementType := "" - query := "" + query = append(query, newQuery) + case models.DmlUpdateType: + queryStr := "UPDATE " + queryStr += db.formatTableName(change.Database, change.Table) + + for i, column := range columnNames { + if i == 0 { + queryStr += fmt.Sprintf(" SET `%s` = ?", column) + } else { + queryStr += fmt.Sprintf(", `%s` = ?", column) + } + } - statementType = "DELETE FROM" + args := make([]interface{}, len(values)) - query = fmt.Sprintf("%s %s WHERE %s = \"%s\"", statementType, db.formatTableName(delete.Table), delete.PrimaryKeyColumnName, delete.PrimaryKeyValue) + copy(args, values) - if query != "" { - queries = append(queries, query) - } - } + queryStr += fmt.Sprintf(" WHERE %s = ?", change.PrimaryKeyColumnName) + args = append(args, change.PrimaryKeyValue) - for _, insert := range inserts { - values := make([]string, 0, len(insert.Values)) + newQuery := models.Query{ + Query: queryStr, + Args: args, + } - for _, value := range insert.Values { - _, err := strconv.ParseFloat(value, 64) + query = append(query, newQuery) + case models.DmlDeleteType: + queryStr := "DELETE FROM " + queryStr += db.formatTableName(change.Database, change.Table) + queryStr += fmt.Sprintf(" WHERE %s = ?", change.PrimaryKeyColumnName) - if strings.ToLower(value) != "default" && err != nil { - values = append(values, fmt.Sprintf("\"%s\"", value)) - } else { - values = append(values, value) + newQuery := models.Query{ + Query: queryStr, + Args: []interface{}{change.PrimaryKeyValue}, } - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", db.formatTableName(insert.Table), strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) - queries = append(queries, query) + query = append(query, newQuery) + } } - tx, err := db.Connection.Begin() + trx, err := db.Connection.Begin() if err != nil { return err } - for _, query := range queries { - - _, err = tx.Exec(query) + for _, query := range query { + logger.Info(query.Query, map[string]any{"args": query.Args}) + _, err := trx.Exec(query.Query, query.Args...) if err != nil { - return errors.Join(err, tx.Rollback()) + return err } } - err = tx.Commit() + err = trx.Commit() if err != nil { return err } + return nil } @@ -483,17 +579,6 @@ func (db *MySQL) GetProvider() string { return db.Provider } -func (db *MySQL) formatTableName(tableName string) string { - splittedTableName := strings.Split(tableName, ".") - - if len(splittedTableName) == 1 { - return tableName - } - - database := splittedTableName[0] - table := splittedTableName[1] - - formattedTableName := fmt.Sprintf("`%s`.`%s`", database, table) - - return formattedTableName +func (db *MySQL) formatTableName(database, table string) string { + return fmt.Sprintf("`%s`.`%s`", database, table) } diff --git a/drivers/postgres.go b/drivers/postgres.go index 0a0cb1c..8ca0597 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -4,13 +4,13 @@ import ( "database/sql" "errors" "fmt" - "strconv" "strings" // import postgresql driver _ "github.com/lib/pq" "github.com/xo/dburl" + "github.com/jorgerojas26/lazysql/helpers/logger" "github.com/jorgerojas26/lazysql/models" ) @@ -26,6 +26,10 @@ const ( defaultPort = "5432" ) +func (db *Postgres) TestConnection(urlstr string) error { + return db.Connect(urlstr) +} + func (db *Postgres) Connect(urlstr string) (err error) { db.SetProvider("postgres") @@ -58,16 +62,21 @@ func (db *Postgres) Connect(urlstr string) (err error) { return nil } -func (db *Postgres) TestConnection(urlstr string) error { - return db.Connect(urlstr) -} - func (db *Postgres) GetDatabases() (databases []string, err error) { rows, err := db.Connection.Query("SELECT datname FROM pg_database;") if err != nil { return nil, err } + defer rows.Close() + + rowsErr := rows.Err() + + if rowsErr != nil { + err = rowsErr + return nil, err + } + for rows.Next() { var database string err := rows.Scan(&database) @@ -83,83 +92,163 @@ func (db *Postgres) GetDatabases() (databases []string, err error) { func (db *Postgres) GetTables(database string) (tables map[string][]string, err error) { tables = make(map[string][]string) - switchDatabase := false + logger.Info("GetTables", map[string]any{"database": database}) + + if database == "" { + return nil, errors.New("database name is required") + } if database != db.CurrentDatabase { err = db.SwitchDatabase(database) if err != nil { return nil, err } - switchDatabase = true } - rows, err := db.Connection.Query(fmt.Sprintf("SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = '%s'", database)) - if err != nil { - if switchDatabase { - err = db.SwitchDatabase(db.PreviousDatabase) - if err != nil { - return nil, err - } + defer func() { + if r := recover(); r != nil { + _ = db.SwitchDatabase(db.PreviousDatabase) } - return tables, nil - } + }() - for rows.Next() { - var tableName string - var tableSchema string + query := "SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = $1" + rows, err := db.Connection.Query(query, database) - err = rows.Scan(&tableName, &tableSchema) - if err != nil { - return nil, err + if rows != nil { + rowsErr := rows.Err() + + if rowsErr != nil { + err = rowsErr } - tables[tableSchema] = append(tables[tableSchema], tableName) + defer rows.Close() + + for rows.Next() { + var tableName string + var tableSchema string + + err = rows.Scan(&tableName, &tableSchema) + + tables[tableSchema] = append(tables[tableSchema], tableName) + + } } + if err != nil { + return nil, err + } + return tables, nil } func (db *Postgres) GetTableColumns(database, table string) (results [][]string, err error) { - tableSchema := strings.Split(table, ".")[0] - tableName := strings.Split(table, ".")[1] - rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s' ORDER by ordinal_position", database, tableSchema, tableName)) - if err != nil { - return nil, err + if database == "" { + return nil, errors.New("database name is required") } - defer rows.Close() - columns, err := rows.Columns() - if err != nil { - return nil, err + if table == "" { + return nil, errors.New("table name is required") } - results = append(results, columns) + splitTableString := strings.Split(table, ".") - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } + if len(splitTableString) == 1 { + return nil, errors.New("table must be in the format schema.table") + } - err = rows.Scan(rowValues...) + if database != db.CurrentDatabase { + err = db.SwitchDatabase(database) if err != nil { return nil, err } + } - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) + defer func() { + if r := recover(); r != nil { + _ = db.SwitchDatabase(db.PreviousDatabase) } + }() - results = append(results, row) + tableSchema := splitTableString[0] + tableName := splitTableString[1] + + query := "SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = $1 AND table_schema = $2 AND table_name = $3 ORDER by ordinal_position" + + rows, err := db.Connection.Query(query, database, tableSchema, tableName) + + if rows != nil { + + rowsErr := rows.Err() + + if rowsErr != nil { + err = rowsErr + } + + defer rows.Close() + + columns, columnsError := rows.Columns() + + if columnsError != nil { + err = columnsError + } + + results = append(results, columns) + + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + + for i := range columns { + rowValues[i] = new(sql.RawBytes) + } + + err = rows.Scan(rowValues...) + + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) + } + + results = append(results, row) + } + + } + + if err != nil { + return nil, err } return } -func (db *Postgres) GetConstraints(table string) (constraints [][]string, err error) { +func (db *Postgres) GetConstraints(database, table string) (constraints [][]string, err error) { + if database == "" { + return nil, errors.New("database name is required") + } + + if table == "" { + return nil, errors.New("table name is required") + } + splitTableString := strings.Split(table, ".") + + if len(splitTableString) == 1 { + return nil, errors.New("table must be in the format schema.table") + } + + if database != db.CurrentDatabase { + err = db.SwitchDatabase(database) + if err != nil { + return nil, err + } + } + + defer func() { + if r := recover(); r != nil { + _ = db.SwitchDatabase(db.PreviousDatabase) + } + }() + tableSchema := splitTableString[0] tableName := splitTableString[1] @@ -179,43 +268,77 @@ func (db *Postgres) GetConstraints(table string) (constraints [][]string, err er AND tc.table_schema = '%s' AND tc.table_name = '%s' `, tableSchema, tableName)) - if err != nil { - return nil, err - } - defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, err - } + if rows != nil { - constraints = append(constraints, columns) + rowsErr := rows.Err() - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) + if rowsErr != nil { + err = rowsErr } - err = rows.Scan(rowValues...) - if err != nil { - return nil, err + defer rows.Close() + + columns, columnsError := rows.Columns() + + if columnsError != nil { + err = columnsError } - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) + constraints = append(constraints, columns) + + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) + } + + err = rows.Scan(rowValues...) + + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) + } + + constraints = append(constraints, row) } + } - constraints = append(constraints, row) + if err != nil { + return nil, err } return } -func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, err error) { +func (db *Postgres) GetForeignKeys(database, table string) (foreignKeys [][]string, err error) { + if database == "" { + return nil, errors.New("database name is required") + } + + if table == "" { + return nil, errors.New("table name is required") + } + splitTableString := strings.Split(table, ".") + + if len(splitTableString) == 1 { + return nil, errors.New("table must be in the format schema.table") + } + + if database != db.CurrentDatabase { + err = db.SwitchDatabase(database) + if err != nil { + return nil, err + } + } + + defer func() { + if r := recover(); r != nil { + _ = db.SwitchDatabase(db.PreviousDatabase) + } + }() + tableSchema := splitTableString[0] tableName := splitTableString[1] @@ -236,43 +359,77 @@ func (db *Postgres) GetForeignKeys(table string) (foreignKeys [][]string, err er AND tc.table_schema = '%s' AND tc.table_name = '%s' `, tableSchema, tableName)) - if err != nil { - return nil, err - } - defer rows.Close() + if rows != nil { - columns, err := rows.Columns() - if err != nil { - return nil, err - } - - foreignKeys = append(foreignKeys, columns) + rowsErr := rows.Err() - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) + if rowsErr != nil { + err = rowsErr } - err = rows.Scan(rowValues...) - if err != nil { - return nil, err + defer rows.Close() + + columns, columnsError := rows.Columns() + + if columnsError != nil { + err = columnsError } - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) + foreignKeys = append(foreignKeys, columns) + + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) + } + + err = rows.Scan(rowValues...) + + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) + } + + foreignKeys = append(foreignKeys, row) } + } - foreignKeys = append(foreignKeys, row) + if err != nil { + return nil, err } return } -func (db *Postgres) GetIndexes(table string) (indexes [][]string, err error) { +func (db *Postgres) GetIndexes(database, table string) (indexes [][]string, err error) { + if database == "" { + return nil, errors.New("database name is required") + } + + if table == "" { + return nil, errors.New("table name is required") + } + splitTableString := strings.Split(table, ".") + + if len(splitTableString) == 1 { + return nil, errors.New("table must be in the format schema.table") + } + + if database != db.CurrentDatabase { + err = db.SwitchDatabase(database) + if err != nil { + return nil, err + } + } + + defer func() { + if r := recover(); r != nil { + _ = db.SwitchDatabase(db.PreviousDatabase) + } + }() + tableSchema := splitTableString[0] tableName := splitTableString[1] @@ -302,123 +459,265 @@ func (db *Postgres) GetIndexes(table string) (indexes [][]string, err error) { t.relname, i.relname `, tableSchema, tableName)) - if err != nil { - return nil, err - } - defer rows.Close() - columns, err := rows.Columns() - if err != nil { - return nil, err - } + if rows != nil { - indexes = append(indexes, columns) + rowsErr := rows.Err() - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) + if rowsErr != nil { + err = rowsErr } - err = rows.Scan(rowValues...) - if err != nil { - return nil, err + defer rows.Close() + + columns, columnsError := rows.Columns() + + if columnsError != nil { + err = columnsError } - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) + indexes = append(indexes, columns) + + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) + } + + err = rows.Scan(rowValues...) + + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) + } + + indexes = append(indexes, row) } + } - indexes = append(indexes, row) + if err != nil { + return nil, err } return } -func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (records [][]string, totalRecords int, err error) { - table = db.formatTableName(table) - defaultLimit := 300 - isPaginationEnabled := offset >= 0 && limit >= 0 +func (db *Postgres) GetRecords(database, table, where, sort string, offset, limit int) (records [][]string, totalRecords int, err error) { + if database == "" { + return nil, 0, errors.New("database name is required") + } + + if table == "" { + return nil, 0, errors.New("table name is required") + } + + splitTableString := strings.Split(table, ".") + + if len(splitTableString) == 1 { + return nil, 0, errors.New("table must be in the format schema.table") + } + + if database != db.CurrentDatabase { + err = db.SwitchDatabase(database) + if err != nil { + return nil, 0, err + } + } + + defer func() { + if r := recover(); r != nil { + if database != db.PreviousDatabase { + _ = db.SwitchDatabase(db.PreviousDatabase) + } + } + }() + + tableSchema := splitTableString[0] + tableName := splitTableString[1] - if limit != 0 { - defaultLimit = limit + formattedTableName := db.formatTableName(tableSchema, tableName) + + if limit == 0 { + limit = DefaultRowLimit } - query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d OFFSET %d", table, defaultLimit, offset) + query := "SELECT * FROM " + query += formattedTableName if where != "" { - query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d OFFSET %d", table, where, defaultLimit, offset) + query += fmt.Sprintf(" %s", where) } if sort != "" { - query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d OFFSET %d", table, where, sort, defaultLimit, offset) + query += fmt.Sprintf(" ORDER BY %s", sort) } - paginatedRows, err := db.Connection.Query(query) - if err != nil { - return nil, 0, err - } + query += " LIMIT $1 OFFSET $2" - if isPaginationEnabled { - queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", table, where) + paginatedRows, err := db.Connection.Query(query, limit, offset) - rows := db.Connection.QueryRow(queryWithoutLimit) + if paginatedRows != nil { - if err != nil { - return nil, 0, err + rowsErr := paginatedRows.Err() + + defer paginatedRows.Close() + + if rowsErr != nil { + err = rowsErr + } + + countQuery := "SELECT COUNT(*) FROM " + countQuery += formattedTableName + + rows := db.Connection.QueryRow(countQuery) + + rowsErr = rows.Err() + + if rowsErr != nil { + err = rowsErr } err = rows.Scan(&totalRecords) - if err != nil { - return nil, 0, err + + columns, columnsError := paginatedRows.Columns() + + if columnsError != nil { + err = columnsError } - defer paginatedRows.Close() + records = append(records, columns) + + for paginatedRows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) + } + + err = paginatedRows.Scan(rowValues...) + + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) + } + + records = append(records, row) + + } } - columns, err := paginatedRows.Columns() if err != nil { return nil, 0, err } - records = append(records, columns) + return +} - for paginatedRows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } +func (db *Postgres) UpdateRecord(database, table, column, value, primaryKeyColumnName, primaryKeyValue string) (err error) { + if database == "" { + return errors.New("database name is required") + } + + if table == "" { + return errors.New("table name is required") + } + + if column == "" { + return errors.New("column name is required") + } + + if value == "" { + return errors.New("value is required") + } + + if primaryKeyColumnName == "" { + return errors.New("primary key column name is required") + } + + if primaryKeyValue == "" { + return errors.New("primary key value is required") + } + + splitTableString := strings.Split(table, ".") + + if len(splitTableString) == 1 { + return errors.New("table must be in the format schema.table") + } + + switchDatabaseOnError := false - err = paginatedRows.Scan(rowValues...) + if database != db.CurrentDatabase { + err = db.SwitchDatabase(database) if err != nil { - return nil, 0, err + return err } + switchDatabaseOnError = true + } - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } + tableSchema := splitTableString[0] + tableName := splitTableString[1] - records = append(records, row) + formattedTableName := db.formatTableName(tableSchema, tableName) - } + query := "UPDATE " + query += formattedTableName + query += fmt.Sprintf(" SET \"%s\" = $1 WHERE \"%s\" = $2", column, primaryKeyColumnName) - return -} + _, err = db.Connection.Exec(query, value, primaryKeyValue) -func (db *Postgres) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) (err error) { - table = db.formatTableName(table) - query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE \"%s\" = '%s'", table, column, value, primaryKeyColumnName, primaryKeyValue) - _, err = db.Connection.Exec(query) + if err != nil && switchDatabaseOnError { + err = db.SwitchDatabase(db.PreviousDatabase) + } return err } -func (db *Postgres) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) (err error) { - table = db.formatTableName(table) - query := fmt.Sprintf("DELETE FROM %s WHERE \"%s\" = '%s'", table, primaryKeyColumnName, primaryKeyValue) - _, err = db.Connection.Exec(query) +func (db *Postgres) DeleteRecord(database, table, primaryKeyColumnName, primaryKeyValue string) (err error) { + if database == "" { + return errors.New("database name is required") + } + + if table == "" { + return errors.New("table name is required") + } + + if primaryKeyColumnName == "" { + return errors.New("primary key column name is required") + } + + if primaryKeyValue == "" { + return errors.New("primary key value is required") + } + + splitTableString := strings.Split(table, ".") + + if len(splitTableString) == 1 { + return errors.New("table must be in the format schema.table") + } + + switchDatabaseOnError := false + + if database != db.CurrentDatabase { + err = db.SwitchDatabase(database) + if err != nil { + return err + } + switchDatabaseOnError = true + } + + tableSchema := splitTableString[0] + tableName := splitTableString[1] + + formattedTableName := db.formatTableName(tableSchema, tableName) + + query := "DELETE FROM " + query += formattedTableName + query += fmt.Sprintf(" WHERE \"%s\" = $1", primaryKeyColumnName) + + _, err = db.Connection.Exec(query, primaryKeyValue) + + if err != nil && switchDatabaseOnError { + err = db.SwitchDatabase(db.PreviousDatabase) + } return err } @@ -444,6 +743,12 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { defer rows.Close() + rowsErr := rows.Err() + + if rowsErr != nil { + err = rowsErr + } + columns, err := rows.Columns() if err != nil { return nil, err @@ -474,87 +779,108 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { return } -func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts []models.DbInsert) (err error) { - queries := make([]string, 0, len(changes)+len(inserts)) - - // This will hold grouped changes by their RowId and Table - groupedUpdated := make(map[string][]models.DbDmlChange) - groupedDeletes := make([]models.DbDmlChange, 0, len(changes)) +func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { + var query []models.Query - // Group changes by RowId and Table for _, change := range changes { - if change.Type == "UPDATE" { - key := fmt.Sprintf("%s|%s|%s", db.formatTableName(change.Table), change.PrimaryKeyColumnName, change.PrimaryKeyValue) - groupedUpdated[key] = append(groupedUpdated[key], change) - } else if change.Type == "DELETE" { - groupedDeletes = append(groupedDeletes, change) + columnNames := []string{} + values := []interface{}{} + valuesPlaceholder := []string{} + placeholderIndex := 1 + + for _, cell := range change.Values { + switch cell.Type { + case models.Empty, models.Null, models.String: + columnNames = append(columnNames, cell.Column) + valuesPlaceholder = append(valuesPlaceholder, fmt.Sprintf("$%d", placeholderIndex)) + placeholderIndex++ + } } - } - // Combine individual changes to SQL statements - for key, changes := range groupedUpdated { - columns := []string{} + for _, cell := range change.Values { + switch cell.Type { + case models.Empty: + values = append(values, "") + case models.Null: + values = append(values, sql.NullString{}) + case models.String: + values = append(values, cell.Value) + case models.Default: + break + } + } - // Split key into table and rowId - splitted := strings.Split(key, "|") - table := db.formatTableName(splitted[0]) - PrimaryKeyColumnName := splitted[1] - primaryKeyValue := splitted[2] + splitTableString := strings.Split(change.Table, ".") - for _, change := range changes { - columns = append(columns, fmt.Sprintf("%s='%s'", change.Column, change.Value)) - } + tableSchema := splitTableString[0] + tableName := splitTableString[1] - // Merge all column updates - updateClause := strings.Join(columns, ", ") + formattedTableName := db.formatTableName(tableSchema, tableName) - query := fmt.Sprintf("UPDATE %s SET %s WHERE \"%s\" = '%s';", table, updateClause, PrimaryKeyColumnName, primaryKeyValue) + switch change.Type { - queries = append(queries, query) - } + case models.DmlInsertType: - for _, del := range groupedDeletes { - statementType := "" - query := "" + queryStr := "INSERT INTO " + formattedTableName + queryStr += fmt.Sprintf(" (%s) VALUES (%s)", strings.Join(columnNames, ", "), strings.Join(valuesPlaceholder, ", ")) - statementType = "DELETE FROM" - query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, db.formatTableName(del.Table), del.PrimaryKeyColumnName, del.PrimaryKeyValue) + newQuery := models.Query{ + Query: queryStr, + Args: values, + } - if query != "" { - queries = append(queries, query) - } - } + query = append(query, newQuery) + case models.DmlUpdateType: + queryStr := "UPDATE " + formattedTableName + + for i, column := range columnNames { + if i == 0 { + queryStr += fmt.Sprintf(" SET \"%s\" = $1", column) + } else { + queryStr += fmt.Sprintf(", \"%s\" = $%d", column, i+1) + } + } - for _, insert := range inserts { - values := make([]string, 0, len(insert.Values)) + args := make([]interface{}, len(values)) - for _, value := range insert.Values { - _, err := strconv.ParseFloat(value, 64) + copy(args, values) - if strings.ToLower(value) != "default" && err != nil { - values = append(values, fmt.Sprintf("'%s'", value)) - } else { - values = append(values, value) + queryStr += fmt.Sprintf(" WHERE \"%s\" = $%d", change.PrimaryKeyColumnName, len(columnNames)+1) + args = append(args, change.PrimaryKeyValue) + + newQuery := models.Query{ + Query: queryStr, + Args: args, } - } - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", db.formatTableName(insert.Table), strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) - queries = append(queries, query) + query = append(query, newQuery) + case models.DmlDeleteType: + queryStr := "DELETE FROM " + formattedTableName + queryStr += fmt.Sprintf(" WHERE %s = $1", change.PrimaryKeyColumnName) + + newQuery := models.Query{ + Query: queryStr, + Args: []interface{}{change.PrimaryKeyValue}, + } + + query = append(query, newQuery) + } } - tx, err := db.Connection.Begin() + trx, err := db.Connection.Begin() if err != nil { return err } - for _, query := range queries { - _, err = tx.Exec(query) + for _, query := range query { + logger.Info(query.Query, map[string]any{"args": query.Args}) + _, err := trx.Exec(query.Query, query.Args...) if err != nil { - return errors.Join(err, tx.Rollback()) + return err } } - err = tx.Commit() + err = trx.Commit() if err != nil { return err } @@ -603,17 +929,6 @@ func (db *Postgres) SwitchDatabase(database string) error { return nil } -func (db *Postgres) formatTableName(table string) string { - splittedTableName := strings.Split(table, ".") - - if len(splittedTableName) == 1 { - return table - } - - schema := splittedTableName[0] - tableName := splittedTableName[1] - - formattedTableName := fmt.Sprintf("\"%s\".\"%s\"", schema, tableName) - - return formattedTableName +func (db *Postgres) formatTableName(database, table string) string { + return fmt.Sprintf("\"%s\".\"%s\"", database, table) } diff --git a/drivers/sqlite.go b/drivers/sqlite.go index 0493b01..a219898 100644 --- a/drivers/sqlite.go +++ b/drivers/sqlite.go @@ -4,12 +4,12 @@ import ( "database/sql" "errors" "fmt" - "strconv" "strings" // import sqlite driver _ "modernc.org/sqlite" + "github.com/jorgerojas26/lazysql/helpers/logger" "github.com/jorgerojas26/lazysql/models" ) @@ -46,6 +46,13 @@ func (db *SQLite) GetDatabases() ([]string, error) { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + + defer rows.Close() + for rows.Next() { var database string err := rows.Scan(&database) @@ -63,14 +70,24 @@ func (db *SQLite) GetDatabases() ([]string, error) { } func (db *SQLite) GetTables(database string) (map[string][]string, error) { - rows, err := db.Connection.Query("SELECT name FROM sqlite_master WHERE type='table'") - - tables := make(map[string][]string) + if database == "" { + return nil, errors.New("database name is required") + } + rows, err := db.Connection.Query("SELECT name FROM sqlite_master WHERE type='table'") if err != nil { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + + defer rows.Close() + + tables := make(map[string][]string) + for rows.Next() { var table string err = rows.Scan(&table) @@ -85,10 +102,20 @@ func (db *SQLite) GetTables(database string) (map[string][]string, error) { } func (db *SQLite) GetTableColumns(_, table string) (results [][]string, err error) { - rows, err := db.Connection.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) + if table == "" { + return nil, errors.New("table name is required") + } + + rows, err := db.Connection.Query(fmt.Sprintf("PRAGMA table_info(%s)", db.formatTableName(table))) if err != nil { return nil, err } + + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -125,12 +152,24 @@ func (db *SQLite) GetTableColumns(_, table string) (results [][]string, err erro return } -func (db *SQLite) GetConstraints(table string) (results [][]string, err error) { - rows, err := db.Connection.Query("SELECT sql FROM sqlite_master WHERE type='table' AND name = '" + table + "'") +func (db *SQLite) GetConstraints(_, table string) (results [][]string, err error) { + if table == "" { + return nil, errors.New("table name is required") + } + + query := "SELECT sql FROM sqlite_master " + query += "WHERE type='table' AND name = ?" + + rows, err := db.Connection.Query(query, table) if err != nil { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -166,12 +205,21 @@ func (db *SQLite) GetConstraints(table string) (results [][]string, err error) { return } -func (db *SQLite) GetForeignKeys(table string) (results [][]string, err error) { +func (db *SQLite) GetForeignKeys(_, table string) (results [][]string, err error) { + if table == "" { + return nil, errors.New("table name is required") + } + rows, err := db.Connection.Query("PRAGMA foreign_key_list(" + table + ")") if err != nil { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -207,11 +255,21 @@ func (db *SQLite) GetForeignKeys(table string) (results [][]string, err error) { return } -func (db *SQLite) GetIndexes(table string) (results [][]string, err error) { +func (db *SQLite) GetIndexes(_, table string) (results [][]string, err error) { + if table == "" { + return nil, errors.New("table name is required") + } + rows, err := db.Connection.Query("PRAGMA index_list(" + table + ")") if err != nil { return nil, err } + + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -247,44 +305,53 @@ func (db *SQLite) GetIndexes(table string) (results [][]string, err error) { return } -func (db *SQLite) GetRecords(table, where, sort string, offset, limit int) (paginatedResults [][]string, totalRecords int, err error) { - defaultLimit := 300 - - isPaginationEnabled := offset >= 0 && limit >= 0 +func (db *SQLite) GetRecords(_, table, where, sort string, offset, limit int) (paginatedResults [][]string, totalRecords int, err error) { + if table == "" { + return nil, 0, errors.New("table name is required") + } - if limit != 0 { - defaultLimit = limit + if limit == 0 { + limit = DefaultRowLimit } - query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d,%d", table, offset, defaultLimit) + query := "SELECT * FROM " + query += db.formatTableName(table) if where != "" { - query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d,%d", table, where, offset, defaultLimit) + query += fmt.Sprintf(" %s", where) } if sort != "" { - query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d,%d", table, where, sort, offset, defaultLimit) + query += fmt.Sprintf(" ORDER BY %s", sort) } - paginatedRows, err := db.Connection.Query(query) + query += " LIMIT ?, ?" + + paginatedRows, err := db.Connection.Query(query, offset, limit) if err != nil { return nil, 0, err } - if isPaginationEnabled { - queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", table, where) + rowsErr := paginatedRows.Err() - rows := db.Connection.QueryRow(queryWithoutLimit) - if err != nil { - return nil, 0, err - } + if rowsErr != nil { + return nil, 0, rowsErr + } - err = rows.Scan(&totalRecords) - if err != nil { - return nil, 0, err - } + defer paginatedRows.Close() + + countQuery := "SELECT COUNT(*) FROM " + countQuery += db.formatTableName(table) + + rows := db.Connection.QueryRow(countQuery) + + if err != nil { + return nil, 0, err + } - defer paginatedRows.Close() + err = rows.Scan(&totalRecords) + if err != nil { + return nil, 0, err } columns, err := paginatedRows.Columns() @@ -307,11 +374,7 @@ func (db *SQLite) GetRecords(table, where, sort string, offset, limit int) (pagi var row []string for _, col := range rowValues { - if col == nil { - row = append(row, "NULL") - } else { - row = append(row, string(*col.(*sql.RawBytes))) - } + row = append(row, string(*col.(*sql.RawBytes))) } paginatedResults = append(paginatedResults, row) @@ -327,6 +390,11 @@ func (db *SQLite) ExecuteQuery(query string) (results [][]string, err error) { return nil, err } + rowsErr := rows.Err() + if rowsErr != nil { + return nil, rowsErr + } + defer rows.Close() columns, err := rows.Columns() @@ -363,15 +431,53 @@ func (db *SQLite) ExecuteQuery(query string) (results [][]string, err error) { return } -func (db *SQLite) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) error { - query := fmt.Sprintf("UPDATE %s SET %s = \"%s\" WHERE %s = %s;", table, column, value, primaryKeyColumnName, primaryKeyValue) - _, err := db.Connection.Exec(query) +func (db *SQLite) UpdateRecord(_, table, column, value, primaryKeyColumnName, primaryKeyValue string) error { + if table == "" { + return errors.New("table name is required") + } + + if column == "" { + return errors.New("column name is required") + } + + if value == "" { + return errors.New("value is required") + } + + if primaryKeyColumnName == "" { + return errors.New("primary key column name is required") + } + + if primaryKeyValue == "" { + return errors.New("primary key value is required") + } + + query := "UPDATE " + query += db.formatTableName(table) + query += fmt.Sprintf(" SET %s = ? WHERE %s = ?", column, primaryKeyColumnName) + + _, err := db.Connection.Exec(query, value, primaryKeyValue) return err } -func (db *SQLite) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) error { - query := fmt.Sprintf("DELETE FROM %s WHERE %s = \"%s\"", table, primaryKeyColumnName, primaryKeyValue) +func (db *SQLite) DeleteRecord(_, table, primaryKeyColumnName, primaryKeyValue string) error { + if table == "" { + return errors.New("table name is required") + } + + if primaryKeyColumnName == "" { + return errors.New("primary key column name is required") + } + + if primaryKeyValue == "" { + return errors.New("primary key value is required") + } + + query := "DELETE FROM " + query += db.formatTableName(table) + query += fmt.Sprintf(" WHERE %s = ?", primaryKeyColumnName) + _, err := db.Connection.Exec(query) return err @@ -391,94 +497,105 @@ func (db *SQLite) ExecuteDMLStatement(query string) (result string, err error) { return fmt.Sprintf("%d rows affected", rowsAffected), nil } -func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange, inserts []models.DbInsert) (err error) { - queries := make([]string, 0, len(changes)+len(inserts)) - - // This will hold grouped changes by their RowId and Table - groupedUpdated := make(map[string][]models.DbDmlChange) - groupedDeletes := make([]models.DbDmlChange, 0, len(changes)) +func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { + var query []models.Query - // Group changes by RowId and Table for _, change := range changes { - if change.Type == "UPDATE" { - key := fmt.Sprintf("%s|%s|%s", change.Table, change.PrimaryKeyColumnName, change.PrimaryKeyValue) - groupedUpdated[key] = append(groupedUpdated[key], change) - } else if change.Type == "DELETE" { - groupedDeletes = append(groupedDeletes, change) + columnNames := []string{} + values := []interface{}{} + valuesPlaceholder := []string{} + + for _, cell := range change.Values { + switch cell.Type { + case models.Empty, models.Null, models.String: + columnNames = append(columnNames, cell.Column) + valuesPlaceholder = append(valuesPlaceholder, "?") + } } - } - - // Combine individual changes to SQL statements - for key, changes := range groupedUpdated { - columns := []string{} - - // Split key into table and rowId - splitted := strings.Split(key, "|") - table := splitted[0] - primaryKeyColumnName := splitted[1] - primaryKeyValue := splitted[2] - - for _, change := range changes { - columns = append(columns, fmt.Sprintf("%s='%s'", change.Column, change.Value)) + logger.Info("Column names", map[string]any{"columnNames": columnNames}) + + for _, cell := range change.Values { + switch cell.Type { + case models.Empty: + values = append(values, "") + case models.Null: + values = append(values, sql.NullString{}) + case models.String: + values = append(values, cell.Value) + case models.Default: + break + } } - // Merge all column updates - updateClause := strings.Join(columns, ", ") - - query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = '%s';", table, updateClause, primaryKeyColumnName, primaryKeyValue) - - queries = append(queries, query) - } + switch change.Type { + case models.DmlInsertType: + queryStr := "INSERT INTO " + queryStr += db.formatTableName(change.Table) + queryStr += fmt.Sprintf(" (%s) VALUES (%s)", strings.Join(columnNames, ", "), strings.Join(valuesPlaceholder, ", ")) - for _, delete := range groupedDeletes { - statementType := "" - query := "" + newQuery := models.Query{ + Query: queryStr, + Args: values, + } - statementType = "DELETE FROM" + query = append(query, newQuery) + case models.DmlUpdateType: + queryStr := "UPDATE " + queryStr += db.formatTableName(change.Table) + + for i, column := range columnNames { + if i == 0 { + queryStr += fmt.Sprintf(" SET `%s` = ?", column) + } else { + queryStr += fmt.Sprintf(", `%s` = ?", column) + } + } - query = fmt.Sprintf("%s %s WHERE %s = \"%s\"", statementType, delete.Table, delete.PrimaryKeyColumnName, delete.PrimaryKeyValue) + args := make([]interface{}, len(values)) - if query != "" { - queries = append(queries, query) - } - } + copy(args, values) - for _, insert := range inserts { - values := make([]string, 0, len(insert.Values)) + queryStr += fmt.Sprintf(" WHERE %s = ?", change.PrimaryKeyColumnName) + args = append(args, change.PrimaryKeyValue) - columnsToBeInserted := insert.Columns + newQuery := models.Query{ + Query: queryStr, + Args: args, + } - for _, value := range insert.Values { - _, err := strconv.ParseFloat(value, 64) + query = append(query, newQuery) + case models.DmlDeleteType: + queryStr := "DELETE FROM " + queryStr += db.formatTableName(change.Table) + queryStr += fmt.Sprintf(" WHERE %s = ?", change.PrimaryKeyColumnName) - if err != nil { - values = append(values, fmt.Sprintf("\"%s\"", value)) - } else { - values = append(values, value) + newQuery := models.Query{ + Query: queryStr, + Args: []interface{}{change.PrimaryKeyValue}, } - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", insert.Table, strings.Join(columnsToBeInserted, ", "), strings.Join(values, ", ")) - queries = append(queries, query) + query = append(query, newQuery) + } } - tx, err := db.Connection.Begin() + trx, err := db.Connection.Begin() if err != nil { return err } - for _, query := range queries { - _, err = tx.Exec(query) + for _, query := range query { + logger.Info(query.Query, map[string]any{"args": query.Args}) + _, err := trx.Exec(query.Query, query.Args...) if err != nil { - return errors.Join(err, tx.Rollback()) + return err } } - err = tx.Commit() + err = trx.Commit() if err != nil { return err } + return nil } @@ -489,3 +606,7 @@ func (db *SQLite) SetProvider(provider string) { func (db *SQLite) GetProvider() string { return db.Provider } + +func (db *SQLite) formatTableName(table string) string { + return fmt.Sprintf("`%s`", table) +} diff --git a/helpers/logger/logger.go b/helpers/logger/logger.go new file mode 100644 index 0000000..2dbe6f2 --- /dev/null +++ b/helpers/logger/logger.go @@ -0,0 +1,135 @@ +package logger + +import ( + "encoding/json" + "fmt" + "log/slog" + "os" + "strings" + "sync" + "time" +) + +type logger struct { + mu sync.Mutex + file *os.File + level slog.Level + output string +} + +type logMessage struct { + Timestamp string `json:"timestamp"` + Level string `json:"level"` + Message string `json:"message"` + Data map[string]any `json:"additional_info,omitempty"` +} + +var logInstance *logger + +func init() { + logInstance = &logger{level: slog.LevelInfo} +} + +func (l *logger) log(level slog.Level, msg string, data map[string]any) { + if level < l.level { + return + } + + logMessage := logMessage{ + Timestamp: time.Now().Format(time.RFC3339), + Level: level.String(), + Message: msg, + Data: data, + } + + logData, err := json.Marshal(logMessage) + if err != nil { + fmt.Println("Error marshaling log message:", err) + return + } + + l.mu.Lock() + defer l.mu.Unlock() + + if l.file == nil { + // maybe add another way to log, I did not want to add fmt.Println since this is a TUI app + return + } + + _, err = l.file.Write(logData) + if err != nil { + return + } + + _, err = l.file.Write([]byte("\n")) + if err != nil { + return + } +} + +func (l *logger) SetFile(filename string) error { + l.mu.Lock() + defer l.mu.Unlock() + + if l.file != nil { + err := l.file.Close() + if err != nil { + return err + } + } + + file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + + l.file = file + l.output = filename + return nil +} + +func (l *logger) SetLevel(level slog.Level) { + l.mu.Lock() + defer l.mu.Unlock() + + l.level = level +} + +func SetLevel(level slog.Level) { + logInstance.SetLevel(level) +} + +func SetFile(filename string) error { + return logInstance.SetFile(filename) +} + +func Debug(msg string, data map[string]any) { + logInstance.log(slog.LevelDebug, msg, data) +} + +func Info(msg string, data map[string]any) { + logInstance.log(slog.LevelInfo, msg, data) +} + +func Warn(msg string, data map[string]any) { + logInstance.log(slog.LevelWarn, msg, data) +} + +func Error(msg string, data map[string]any) { + logInstance.log(slog.LevelError, msg, data) +} + +func ParseLogLevel(s string) (slog.Level, error) { + switch strings.ToLower(s) { + case "debug": + return slog.LevelDebug, nil + case "info": + return slog.LevelInfo, nil + case "warn": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + default: + return slog.LevelInfo, fmt.Errorf("unknown log level %q", s) + } +} diff --git a/main.go b/main.go index 2e01b09..074efd7 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "io" "log" "os" @@ -9,14 +10,34 @@ import ( "github.com/jorgerojas26/lazysql/app" "github.com/jorgerojas26/lazysql/components" + "github.com/jorgerojas26/lazysql/helpers/logger" ) var version = "dev" func main() { - err := mysql.SetLogger(log.New(io.Discard, "", 0)) - if err != nil { - panic(err) + rawLogLvl := flag.String("loglvl", "info", "Log level") + logFile := flag.String("logfile", "", "Log file") + flag.Parse() + + logLvl, parseError := logger.ParseLogLevel(*rawLogLvl) + if parseError != nil { + panic(parseError) + } + logger.SetLevel(logLvl) + + if *logFile != "" { + fileError := logger.SetFile(*logFile) + if fileError != nil { + panic(fileError) + } + } + + logger.Info("Starting LazySQL...", nil) + + mysqlError := mysql.SetLogger(log.New(io.Discard, "", 0)) + if mysqlError != nil { + panic(mysqlError) } // check if "version" arg is passed diff --git a/models/models.go b/models/models.go index 371fdd5..e8eb5ea 100644 --- a/models/models.go +++ b/models/models.go @@ -1,7 +1,6 @@ package models import ( - "github.com/google/uuid" "github.com/rivo/tview" ) @@ -22,22 +21,40 @@ type ConnectionPages struct { *tview.Pages } +type ( + CellValueType int8 + DmlType int8 +) + +// This is not a direct map of the database types, but rather a way to represent them in the UI. +// So the String type is a representation of the cell value in the UI table and the others are +// just a representation of the values that you can put in the database but not in the UI as a string of characters. +const ( + Empty CellValueType = iota + Null + Default + String +) + +type CellValue struct { + Type CellValueType + Column string + Value interface{} +} + +const ( + DmlUpdateType DmlType = iota + DmlDeleteType + DmlInsertType +) + type DbDmlChange struct { - Type string + Type DmlType + Database string Table string - Column string - Value string + Values []CellValue PrimaryKeyColumnName string PrimaryKeyValue string - Option int -} - -type DbInsert struct { - Table string - Columns []string - Values []string - Option int - PrimaryKeyValue uuid.UUID } type DatabaseTableColumn struct { @@ -48,3 +65,8 @@ type DatabaseTableColumn struct { Default string Extra string } + +type Query struct { + Query string + Args []interface{} +} From 6690fe8377b6e458451890f21d3971de1b63563e Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Sun, 25 Aug 2024 16:53:54 -0400 Subject: [PATCH 2/3] fix: linter issues --- components/Home.go | 2 +- drivers/mysql.go | 2 -- drivers/postgres.go | 8 +++++--- drivers/sqlite.go | 2 -- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/components/Home.go b/components/Home.go index 4ceee8e..ad82d53 100644 --- a/components/Home.go +++ b/components/Home.go @@ -280,7 +280,7 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { App.Stop() } } else if command == commands.Save { - if (home.ListOfDbChanges != nil && len(home.ListOfDbChanges) > 0) && !table.GetIsEditing() { + if (len(home.ListOfDbChanges) > 0) && !table.GetIsEditing() { confirmationModal := NewConfirmationModal("") confirmationModal.SetDoneFunc(func(_ int, buttonLabel string) { diff --git a/drivers/mysql.go b/drivers/mysql.go index 05472b0..1c47412 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -494,8 +494,6 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) values = append(values, sql.NullString{}) case models.String: values = append(values, cell.Value) - case models.Default: - break } } diff --git a/drivers/postgres.go b/drivers/postgres.go index 8ca0597..a3fede9 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -805,8 +805,6 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err values = append(values, sql.NullString{}) case models.String: values = append(values, cell.Value) - case models.Default: - break } } @@ -921,7 +919,11 @@ func (db *Postgres) SwitchDatabase(database string) error { return err } - db.Connection.Close() + err = db.Connection.Close() + if err != nil { + return err + } + db.Connection = connection db.PreviousDatabase = db.CurrentDatabase db.CurrentDatabase = database diff --git a/drivers/sqlite.go b/drivers/sqlite.go index a219898..24fc476 100644 --- a/drivers/sqlite.go +++ b/drivers/sqlite.go @@ -522,8 +522,6 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error values = append(values, sql.NullString{}) case models.String: values = append(values, cell.Value) - case models.Default: - break } } From cfd8452244a387f5440ba74d9e90960ca8d4d964 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Sun, 25 Aug 2024 17:12:43 -0400 Subject: [PATCH 3/3] add quit button to connection selection --- components/ConnectionSelection.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/components/ConnectionSelection.go b/components/ConnectionSelection.go index b90fb85..7bfd2cb 100644 --- a/components/ConnectionSelection.go +++ b/components/ConnectionSelection.go @@ -59,6 +59,9 @@ func NewConnectionSelection(connectionForm *ConnectionForm, connectionPages *mod quitButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) quitButton.SetBorder(true) + buttonsWrapper.AddItem(quitButton, 0, 1, false) + buttonsWrapper.AddItem(nil, 1, 0, false) + statusText := tview.NewTextView() statusText.SetBorderPadding(1, 1, 0, 0)