diff --git a/internal/app/app.go b/internal/app/app.go index 8f305f765..46f46e697 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -10,6 +10,7 @@ import ( "time" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/commandhistory" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" @@ -26,10 +27,11 @@ import ( ) type App struct { - Sessions session.Service - Messages message.Service - History history.Service - Permissions permission.Service + Sessions session.Service + Messages message.Service + History history.Service + CommandHistory commandhistory.Service + Permissions permission.Service CoderAgent agent.Service @@ -53,6 +55,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { sessions := session.NewService(q) messages := message.NewService(q) files := history.NewService(q, conn) + commandHistory := commandhistory.NewService(q, conn) skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests allowedTools := []string{} if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil { @@ -60,11 +63,12 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { } app := &App{ - Sessions: sessions, - Messages: messages, - History: files, - Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools), - LSPClients: csync.NewMap[string, *lsp.Client](), + Sessions: sessions, + Messages: messages, + History: files, + CommandHistory: commandHistory, + Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools), + LSPClients: csync.NewMap[string, *lsp.Client](), globalCtx: ctx, diff --git a/internal/commandhistory/service.go b/internal/commandhistory/service.go new file mode 100644 index 000000000..4a39a9e34 --- /dev/null +++ b/internal/commandhistory/service.go @@ -0,0 +1,136 @@ +package commandhistory + +import ( + "context" + "database/sql" + "strings" + + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/google/uuid" +) + +type CommandHistory struct { + ID string + SessionID string + Command string + CreatedAt int64 + UpdatedAt int64 +} + +type Service interface { + pubsub.Suscriber[CommandHistory] + Add(ctx context.Context, sessionID, command string) (CommandHistory, error) + ListBySession(ctx context.Context, sessionID string, limit int) ([]CommandHistory, error) + DeleteSessionHistory(ctx context.Context, sessionID string) error +} + +type service struct { + *pubsub.Broker[CommandHistory] + db *sql.DB + q *db.Queries +} + +const MaxHistorySize = 1000 + +func NewService(q *db.Queries, db *sql.DB) Service { + return &service{ + Broker: pubsub.NewBroker[CommandHistory](), + q: q, + db: db, + } +} + +func (s *service) Add(ctx context.Context, sessionID, command string) (CommandHistory, error) { + command = strings.TrimSpace(command) + if command == "" { + return CommandHistory{}, nil + } + + // Get current count for this session + countRow, err := s.q.GetCommandHistoryCount(ctx, db.GetCommandHistoryCountParams{ + SessionID: sessionID, + }) + if err != nil { + return CommandHistory{}, err + } + + // If we're at the limit, remove oldest entries + if int(countRow.Count) >= MaxHistorySize { + history, err := s.q.ListCommandHistoryBySession(ctx, db.ListCommandHistoryBySessionParams{ + SessionID: sessionID, + }) + if err != nil { + return CommandHistory{}, err + } + + // Remove oldest entries to make room + toRemove := int(countRow.Count) - MaxHistorySize + 1 + for i := 0; i < toRemove && i < len(history); i++ { + // Simple deletion - in a more sophisticated implementation, + // we might want to batch delete + if _, err := s.db.ExecContext(ctx, "DELETE FROM command_history WHERE id = ?", history[i].ID); err != nil { + return CommandHistory{}, err + } + } + } + + dbHistory, err := s.q.CreateCommandHistory(ctx, db.CreateCommandHistoryParams{ + ID: uuid.New().String(), + SessionID: sessionID, + Command: command, + }) + if err != nil { + return CommandHistory{}, err + } + + history := CommandHistory{ + ID: dbHistory.ID, + SessionID: dbHistory.SessionID, + Command: dbHistory.Command, + CreatedAt: dbHistory.CreatedAt, + UpdatedAt: dbHistory.UpdatedAt, + } + + s.Publish(pubsub.CreatedEvent, history) + return history, nil +} + +func (s *service) ListBySession(ctx context.Context, sessionID string, limit int) ([]CommandHistory, error) { + if limit <= 0 { + limit = MaxHistorySize + } + + dbHistory, err := s.q.ListLatestCommandHistoryBySession(ctx, db.ListLatestCommandHistoryBySessionParams{ + SessionID: sessionID, + Limit: int64(limit), + }) + if err != nil { + return nil, err + } + + history := make([]CommandHistory, len(dbHistory)) + for i, dbItem := range dbHistory { + // Reverse the slice so callers see commands in chronological order. + history[len(dbHistory)-1-i] = CommandHistory{ + ID: dbItem.ID, + SessionID: dbItem.SessionID, + Command: dbItem.Command, + CreatedAt: dbItem.CreatedAt, + UpdatedAt: dbItem.UpdatedAt, + } + } + return history, nil +} + +func (s *service) DeleteSessionHistory(ctx context.Context, sessionID string) error { + err := s.q.DeleteSessionCommandHistory(ctx, db.DeleteSessionCommandHistoryParams{ + SessionID: sessionID, + }) + if err != nil { + return err + } + // Publish deletion event + s.Publish(pubsub.DeletedEvent, CommandHistory{SessionID: sessionID}) + return nil +} diff --git a/internal/commandhistory/service_test.go b/internal/commandhistory/service_test.go new file mode 100644 index 000000000..358c996b7 --- /dev/null +++ b/internal/commandhistory/service_test.go @@ -0,0 +1,70 @@ +package commandhistory + +import ( + "context" + "testing" + + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/session" + "github.com/stretchr/testify/require" +) + +func setupTestService(t *testing.T) (context.Context, *service, *db.Queries) { + t.Helper() + + ctx := context.Background() + conn, err := db.Connect(ctx, t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, conn.Close()) + }) + + queries := db.New(conn) + svc := NewService(queries, conn) + return ctx, svc.(*service), queries +} + +func TestListBySessionReturnsChronologicalHistory(t *testing.T) { + ctx, svc, queries := setupTestService(t) + + sessionSvc := session.NewService(queries) + sess, err := sessionSvc.Create(ctx, "test session") + require.NoError(t, err) + + // Seed deterministic history with known timestamps. + rows := []struct { + id string + command string + ts int64 + }{ + {id: "cmd-1", command: "first", ts: 1}, + {id: "cmd-2", command: "second", ts: 2}, + {id: "cmd-3", command: "third", ts: 3}, + } + + for _, row := range rows { + _, err := svc.db.ExecContext(ctx, ` + INSERT INTO command_history (id, session_id, command, created_at, updated_at) + VALUES (?, ?, ?, ?, ?)`, + row.id, sess.ID, row.command, row.ts, row.ts, + ) + require.NoError(t, err) + } + + history, err := svc.ListBySession(ctx, sess.ID, 0) + require.NoError(t, err) + require.Len(t, history, 3) + require.Equal(t, []string{"first", "second", "third"}, []string{ + history[0].Command, + history[1].Command, + history[2].Command, + }) + + limitedHistory, err := svc.ListBySession(ctx, sess.ID, 2) + require.NoError(t, err) + require.Len(t, limitedHistory, 2) + require.Equal(t, []string{"second", "third"}, []string{ + limitedHistory[0].Command, + limitedHistory[1].Command, + }) +} diff --git a/internal/db/command_history.sql.go b/internal/db/command_history.sql.go new file mode 100644 index 000000000..06487b69a --- /dev/null +++ b/internal/db/command_history.sql.go @@ -0,0 +1,136 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: command_history.sql + +package db + +import ( + "context" +) + +const createCommandHistory = `-- name: CreateCommandHistory :one +INSERT INTO command_history ( + id, + session_id, + command, + created_at, + updated_at +) VALUES ( + ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') +) +RETURNING id, session_id, command, created_at, updated_at +` + +func (q *Queries) CreateCommandHistory(ctx context.Context, arg CreateCommandHistoryParams) (CommandHistory, error) { + row := q.queryRow(ctx, q.createCommandHistoryStmt, createCommandHistory, + arg.ID, + arg.SessionID, + arg.Command, + ) + var i CommandHistory + err := row.Scan( + &i.ID, + &i.SessionID, + &i.Command, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteSessionCommandHistory = `-- name: DeleteSessionCommandHistory :exec +DELETE FROM command_history +WHERE session_id = ? +` + +func (q *Queries) DeleteSessionCommandHistory(ctx context.Context, arg DeleteSessionCommandHistoryParams) error { + _, err := q.exec(ctx, q.deleteSessionCommandHistoryStmt, deleteSessionCommandHistory, arg.SessionID) + return err +} + +const getCommandHistoryCount = `-- name: GetCommandHistoryCount :one +SELECT COUNT(*) as count +FROM command_history +WHERE session_id = ? +` + +func (q *Queries) GetCommandHistoryCount(ctx context.Context, arg GetCommandHistoryCountParams) (GetCommandHistoryCountRow, error) { + row := q.queryRow(ctx, q.getCommandHistoryCountStmt, getCommandHistoryCount, arg.SessionID) + var i GetCommandHistoryCountRow + err := row.Scan(&i.Count) + return i, err +} + +const listCommandHistoryBySession = `-- name: ListCommandHistoryBySession :many +SELECT id, session_id, command, created_at, updated_at +FROM command_history +WHERE session_id = ? +ORDER BY created_at ASC +` + +func (q *Queries) ListCommandHistoryBySession(ctx context.Context, arg ListCommandHistoryBySessionParams) ([]CommandHistory, error) { + rows, err := q.query(ctx, q.listCommandHistoryBySessionStmt, listCommandHistoryBySession, arg.SessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []CommandHistory + for rows.Next() { + var i CommandHistory + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Command, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listLatestCommandHistoryBySession = `-- name: ListLatestCommandHistoryBySession :many +SELECT id, session_id, command, created_at, updated_at +FROM command_history +WHERE session_id = ? +ORDER BY created_at DESC +LIMIT ? +` + +func (q *Queries) ListLatestCommandHistoryBySession(ctx context.Context, arg ListLatestCommandHistoryBySessionParams) ([]CommandHistory, error) { + rows, err := q.query(ctx, q.listLatestCommandHistoryBySessionStmt, listLatestCommandHistoryBySession, arg.SessionID, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []CommandHistory + for rows.Next() { + var i CommandHistory + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Command, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} \ No newline at end of file diff --git a/internal/db/db.go b/internal/db/db.go index 62ebe0134..66897fb2e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -24,6 +24,9 @@ func New(db DBTX) *Queries { func Prepare(ctx context.Context, db DBTX) (*Queries, error) { q := Queries{db: db} var err error + if q.createCommandHistoryStmt, err = db.PrepareContext(ctx, createCommandHistory); err != nil { + return nil, fmt.Errorf("error preparing query CreateCommandHistory: %w", err) + } if q.createFileStmt, err = db.PrepareContext(ctx, createFile); err != nil { return nil, fmt.Errorf("error preparing query CreateFile: %w", err) } @@ -39,6 +42,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.deleteMessageStmt, err = db.PrepareContext(ctx, deleteMessage); err != nil { return nil, fmt.Errorf("error preparing query DeleteMessage: %w", err) } + if q.deleteSessionCommandHistoryStmt, err = db.PrepareContext(ctx, deleteSessionCommandHistory); err != nil { + return nil, fmt.Errorf("error preparing query DeleteSessionCommandHistory: %w", err) + } if q.deleteSessionStmt, err = db.PrepareContext(ctx, deleteSession); err != nil { return nil, fmt.Errorf("error preparing query DeleteSession: %w", err) } @@ -48,6 +54,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.deleteSessionMessagesStmt, err = db.PrepareContext(ctx, deleteSessionMessages); err != nil { return nil, fmt.Errorf("error preparing query DeleteSessionMessages: %w", err) } + if q.getCommandHistoryCountStmt, err = db.PrepareContext(ctx, getCommandHistoryCount); err != nil { + return nil, fmt.Errorf("error preparing query GetCommandHistoryCount: %w", err) + } if q.getFileStmt, err = db.PrepareContext(ctx, getFile); err != nil { return nil, fmt.Errorf("error preparing query GetFile: %w", err) } @@ -60,12 +69,18 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil { return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err) } + if q.listCommandHistoryBySessionStmt, err = db.PrepareContext(ctx, listCommandHistoryBySession); err != nil { + return nil, fmt.Errorf("error preparing query ListCommandHistoryBySession: %w", err) + } if q.listFilesByPathStmt, err = db.PrepareContext(ctx, listFilesByPath); err != nil { return nil, fmt.Errorf("error preparing query ListFilesByPath: %w", err) } if q.listFilesBySessionStmt, err = db.PrepareContext(ctx, listFilesBySession); err != nil { return nil, fmt.Errorf("error preparing query ListFilesBySession: %w", err) } + if q.listLatestCommandHistoryBySessionStmt, err = db.PrepareContext(ctx, listLatestCommandHistoryBySession); err != nil { + return nil, fmt.Errorf("error preparing query ListLatestCommandHistoryBySession: %w", err) + } if q.listLatestSessionFilesStmt, err = db.PrepareContext(ctx, listLatestSessionFiles); err != nil { return nil, fmt.Errorf("error preparing query ListLatestSessionFiles: %w", err) } @@ -89,6 +104,11 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { func (q *Queries) Close() error { var err error + if q.createCommandHistoryStmt != nil { + if cerr := q.createCommandHistoryStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing createCommandHistoryStmt: %w", cerr) + } + } if q.createFileStmt != nil { if cerr := q.createFileStmt.Close(); cerr != nil { err = fmt.Errorf("error closing createFileStmt: %w", cerr) @@ -114,6 +134,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing deleteMessageStmt: %w", cerr) } } + if q.deleteSessionCommandHistoryStmt != nil { + if cerr := q.deleteSessionCommandHistoryStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing deleteSessionCommandHistoryStmt: %w", cerr) + } + } if q.deleteSessionStmt != nil { if cerr := q.deleteSessionStmt.Close(); cerr != nil { err = fmt.Errorf("error closing deleteSessionStmt: %w", cerr) @@ -129,6 +154,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing deleteSessionMessagesStmt: %w", cerr) } } + if q.getCommandHistoryCountStmt != nil { + if cerr := q.getCommandHistoryCountStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getCommandHistoryCountStmt: %w", cerr) + } + } if q.getFileStmt != nil { if cerr := q.getFileStmt.Close(); cerr != nil { err = fmt.Errorf("error closing getFileStmt: %w", cerr) @@ -149,6 +179,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr) } } + if q.listCommandHistoryBySessionStmt != nil { + if cerr := q.listCommandHistoryBySessionStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing listCommandHistoryBySessionStmt: %w", cerr) + } + } if q.listFilesByPathStmt != nil { if cerr := q.listFilesByPathStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listFilesByPathStmt: %w", cerr) @@ -159,6 +194,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing listFilesBySessionStmt: %w", cerr) } } + if q.listLatestCommandHistoryBySessionStmt != nil { + if cerr := q.listLatestCommandHistoryBySessionStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing listLatestCommandHistoryBySessionStmt: %w", cerr) + } + } if q.listLatestSessionFilesStmt != nil { if cerr := q.listLatestSessionFilesStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listLatestSessionFilesStmt: %w", cerr) @@ -226,48 +266,58 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar } type Queries struct { - db DBTX - tx *sql.Tx - createFileStmt *sql.Stmt - createMessageStmt *sql.Stmt - createSessionStmt *sql.Stmt - deleteFileStmt *sql.Stmt - deleteMessageStmt *sql.Stmt - deleteSessionStmt *sql.Stmt - deleteSessionFilesStmt *sql.Stmt - deleteSessionMessagesStmt *sql.Stmt - getFileStmt *sql.Stmt - getFileByPathAndSessionStmt *sql.Stmt - getMessageStmt *sql.Stmt - getSessionByIDStmt *sql.Stmt - listFilesByPathStmt *sql.Stmt - listFilesBySessionStmt *sql.Stmt - listLatestSessionFilesStmt *sql.Stmt - listMessagesBySessionStmt *sql.Stmt - listNewFilesStmt *sql.Stmt - listSessionsStmt *sql.Stmt - updateMessageStmt *sql.Stmt - updateSessionStmt *sql.Stmt + db DBTX + tx *sql.Tx + createCommandHistoryStmt *sql.Stmt + createFileStmt *sql.Stmt + createMessageStmt *sql.Stmt + createSessionStmt *sql.Stmt + deleteFileStmt *sql.Stmt + deleteMessageStmt *sql.Stmt + deleteSessionCommandHistoryStmt *sql.Stmt + deleteSessionStmt *sql.Stmt + deleteSessionFilesStmt *sql.Stmt + deleteSessionMessagesStmt *sql.Stmt + getCommandHistoryCountStmt *sql.Stmt + getFileStmt *sql.Stmt + getFileByPathAndSessionStmt *sql.Stmt + getMessageStmt *sql.Stmt + getSessionByIDStmt *sql.Stmt + listCommandHistoryBySessionStmt *sql.Stmt + listFilesByPathStmt *sql.Stmt + listFilesBySessionStmt *sql.Stmt + listLatestCommandHistoryBySessionStmt *sql.Stmt + listLatestSessionFilesStmt *sql.Stmt + listMessagesBySessionStmt *sql.Stmt + listNewFilesStmt *sql.Stmt + listSessionsStmt *sql.Stmt + updateMessageStmt *sql.Stmt + updateSessionStmt *sql.Stmt } func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ db: tx, tx: tx, + createCommandHistoryStmt: q.createCommandHistoryStmt, createFileStmt: q.createFileStmt, createMessageStmt: q.createMessageStmt, createSessionStmt: q.createSessionStmt, deleteFileStmt: q.deleteFileStmt, deleteMessageStmt: q.deleteMessageStmt, + deleteSessionCommandHistoryStmt: q.deleteSessionCommandHistoryStmt, deleteSessionStmt: q.deleteSessionStmt, deleteSessionFilesStmt: q.deleteSessionFilesStmt, deleteSessionMessagesStmt: q.deleteSessionMessagesStmt, + getCommandHistoryCountStmt: q.getCommandHistoryCountStmt, getFileStmt: q.getFileStmt, getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, getMessageStmt: q.getMessageStmt, getSessionByIDStmt: q.getSessionByIDStmt, + listCommandHistoryBySessionStmt: q.listCommandHistoryBySessionStmt, listFilesByPathStmt: q.listFilesByPathStmt, listFilesBySessionStmt: q.listFilesBySessionStmt, + listLatestCommandHistoryBySessionStmt: q.listLatestCommandHistoryBySessionStmt, listLatestSessionFilesStmt: q.listLatestSessionFilesStmt, listMessagesBySessionStmt: q.listMessagesBySessionStmt, listNewFilesStmt: q.listNewFilesStmt, diff --git a/internal/db/migrations/20251018000000_add_command_history.sql b/internal/db/migrations/20251018000000_add_command_history.sql new file mode 100644 index 000000000..5e1c4a5ae --- /dev/null +++ b/internal/db/migrations/20251018000000_add_command_history.sql @@ -0,0 +1,29 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE IF NOT EXISTS command_history ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + command TEXT NOT NULL, + created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds + updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds + FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_command_history_session_id ON command_history(session_id); +CREATE INDEX IF NOT EXISTS idx_command_history_created_at ON command_history(created_at); + +CREATE TRIGGER IF NOT EXISTS update_command_history_updated_at +AFTER UPDATE ON command_history +BEGIN + UPDATE command_history SET updated_at = strftime('%s', 'now') + WHERE id = new.id; +END; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TRIGGER IF EXISTS update_command_history_updated_at; +DROP INDEX IF EXISTS idx_command_history_created_at; +DROP INDEX IF EXISTS idx_command_history_session_id; +DROP TABLE IF EXISTS command_history; +-- +goose StatementEnd \ No newline at end of file diff --git a/internal/db/models.go b/internal/db/models.go index ec3e6e10a..61264aa52 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -42,3 +42,38 @@ type Session struct { CreatedAt int64 `json:"created_at"` SummaryMessageID sql.NullString `json:"summary_message_id"` } + +type CommandHistory struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + Command string `json:"command"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type CreateCommandHistoryParams struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + Command string `json:"command"` +} + +type DeleteSessionCommandHistoryParams struct { + SessionID string `json:"session_id"` +} + +type GetCommandHistoryCountParams struct { + SessionID string `json:"session_id"` +} + +type ListCommandHistoryBySessionParams struct { + SessionID string `json:"session_id"` +} + +type ListLatestCommandHistoryBySessionParams struct { + SessionID string `json:"session_id"` + Limit int64 `json:"limit"` +} + +type GetCommandHistoryCountRow struct { + Count int64 `json:"count"` +} diff --git a/internal/db/querier.go b/internal/db/querier.go index 472137273..369929d87 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -9,20 +9,25 @@ import ( ) type Querier interface { + CreateCommandHistory(ctx context.Context, arg CreateCommandHistoryParams) (CommandHistory, error) CreateFile(ctx context.Context, arg CreateFileParams) (File, error) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) DeleteFile(ctx context.Context, id string) error DeleteMessage(ctx context.Context, id string) error DeleteSession(ctx context.Context, id string) error + DeleteSessionCommandHistory(ctx context.Context, arg DeleteSessionCommandHistoryParams) error DeleteSessionFiles(ctx context.Context, sessionID string) error DeleteSessionMessages(ctx context.Context, sessionID string) error + GetCommandHistoryCount(ctx context.Context, arg GetCommandHistoryCountParams) (GetCommandHistoryCountRow, error) GetFile(ctx context.Context, id string) (File, error) GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error) GetMessage(ctx context.Context, id string) (Message, error) GetSessionByID(ctx context.Context, id string) (Session, error) + ListCommandHistoryBySession(ctx context.Context, arg ListCommandHistoryBySessionParams) ([]CommandHistory, error) ListFilesByPath(ctx context.Context, path string) ([]File, error) ListFilesBySession(ctx context.Context, sessionID string) ([]File, error) + ListLatestCommandHistoryBySession(ctx context.Context, arg ListLatestCommandHistoryBySessionParams) ([]CommandHistory, error) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error) ListNewFiles(ctx context.Context) ([]File, error) diff --git a/internal/db/sql/command_history.sql b/internal/db/sql/command_history.sql new file mode 100644 index 000000000..e616cf8e0 --- /dev/null +++ b/internal/db/sql/command_history.sql @@ -0,0 +1,33 @@ +-- name: CreateCommandHistory :one +INSERT INTO command_history ( + id, + session_id, + command, + created_at, + updated_at +) VALUES ( + ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') +) +RETURNING *; + +-- name: ListCommandHistoryBySession :many +SELECT * +FROM command_history +WHERE session_id = ? +ORDER BY created_at ASC; + +-- name: ListLatestCommandHistoryBySession :many +SELECT * +FROM command_history +WHERE session_id = ? +ORDER BY created_at DESC +LIMIT ?; + +-- name: DeleteSessionCommandHistory :exec +DELETE FROM command_history +WHERE session_id = ?; + +-- name: GetCommandHistoryCount :one +SELECT COUNT(*) as count +FROM command_history +WHERE session_id = ?; \ No newline at end of file diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index f70a0a3db..ad61577b4 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -11,12 +11,14 @@ import ( "runtime" "slices" "strings" + "time" "unicode" "github.com/charmbracelet/bubbles/v2/key" "github.com/charmbracelet/bubbles/v2/textarea" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/commandhistory" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/session" @@ -67,6 +69,12 @@ type editorCmp struct { currentQuery string completionsStartIndex int isCompletionsOpen bool + + // Command history fields + history []commandhistory.CommandHistory + historyIndex int + tempInput string + isInHistoryMode bool } var DeleteKeyMaps = DeleteAttachmentKeyMaps{ @@ -141,29 +149,38 @@ func (m *editorCmp) send() tea.Cmd { value := m.textarea.Value() value = strings.TrimSpace(value) + // Add to history before processing + cmds := []tea.Cmd{m.addToHistory(value)} + switch value { case "exit", "quit": m.textarea.Reset() - return util.CmdHandler(dialogs.OpenDialogMsg{Model: quit.NewQuitDialog()}) + m.historyIndex = len(m.history) + m.isInHistoryMode = false + cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{Model: quit.NewQuitDialog()})) + return tea.Batch(cmds...) } m.textarea.Reset() + m.historyIndex = len(m.history) + m.isInHistoryMode = false attachments := m.attachments m.attachments = nil if value == "" { - return nil + return tea.Batch(cmds...) } // Change the placeholder when sending a new message. m.randomizePlaceholders() - return tea.Batch( + cmds = append(cmds, util.CmdHandler(chat.SendMsg{ Text: value, Attachments: attachments, }), ) + return tea.Batch(cmds...) } func (m *editorCmp) repositionCompletions() tea.Msg { @@ -296,6 +313,15 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } } + // Handle history navigation before other keys + if m.textarea.Focused() && !m.isCompletionsOpen { + if key.Matches(msg, m.keyMap.HistoryUp) { + return m, m.navigateHistory(-1) + } + if key.Matches(msg, m.keyMap.HistoryDown) { + return m, m.navigateHistory(1) + } + } if key.Matches(msg, m.keyMap.OpenEditor) { if m.app.CoderAgent.IsSessionBusy(m.session.ID) { return m, util.ReportWarn("Agent is working, please wait...") @@ -391,6 +417,85 @@ func (m *editorCmp) Cursor() *tea.Cursor { return cursor } +// Add method to load history for session +func (m *editorCmp) loadHistory(ctx context.Context) error { + if m.session.ID == "" { + return nil + } + + history, err := m.app.CommandHistory.ListBySession(ctx, m.session.ID, 0) // 0 = no limit + if err != nil { + return err + } + + m.history = history + m.historyIndex = len(m.history) + return nil +} + +// Add method to navigate history +func (m *editorCmp) navigateHistory(direction int) tea.Cmd { + if len(m.history) == 0 { + return nil + } + + // Save current input when first entering history mode + if !m.isInHistoryMode { + m.tempInput = m.textarea.Value() + m.isInHistoryMode = true + } + + newIndex := m.historyIndex + direction + + if direction < 0 { // Up arrow + if newIndex >= 0 { + m.historyIndex = newIndex + // Since history is in descending order (latest first), + // we can access directly by index + m.textarea.SetValue(m.history[m.historyIndex].Command) + m.textarea.CursorEnd() + } + } else { // Down arrow + if newIndex < len(m.history) { + m.historyIndex = newIndex + m.textarea.SetValue(m.history[m.historyIndex].Command) + m.textarea.CursorEnd() + } else { + // Return to saved input + m.textarea.SetValue(m.tempInput) + m.textarea.CursorEnd() + m.isInHistoryMode = false + m.historyIndex = len(m.history) + } + } + + return nil +} + +// Add method to add command to history +func (m *editorCmp) addToHistory(command string) tea.Cmd { + if m.session.ID == "" || strings.TrimSpace(command) == "" { + return nil + } + + return func() tea.Msg { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := m.app.CommandHistory.Add(ctx, m.session.ID, command) + if err != nil { + return util.ReportError(err) + } + + // Reload history to include the new command + if err := m.loadHistory(ctx); err != nil { + return util.ReportError(err) + } + + return nil + } +} + var readyPlaceholders = [...]string{ "Ready!", "Ready...", @@ -528,7 +633,17 @@ func (c *editorCmp) Bindings() []key.Binding { // we need to move some functionality to the page level func (c *editorCmp) SetSession(session session.Session) tea.Cmd { c.session = session - return nil + c.history = []commandhistory.CommandHistory{} + c.historyIndex = 0 + c.isInHistoryMode = false + c.tempInput = "" + + return func() tea.Msg { + if err := c.loadHistory(context.Background()); err != nil { + return util.ReportError(err) + } + return nil + } } func (c *editorCmp) IsCompletionsOpen() bool { @@ -575,9 +690,13 @@ func New(app *app.App) Editor { ta.Focus() e := &editorCmp{ // TODO: remove the app instance from here - app: app, - textarea: ta, - keyMap: DefaultEditorKeyMap(), + app: app, + textarea: ta, + keyMap: DefaultEditorKeyMap(), + history: []commandhistory.CommandHistory{}, + historyIndex: 0, + tempInput: "", + isInHistoryMode: false, } e.setEditorPrompt() diff --git a/internal/tui/components/chat/editor/history_test.go b/internal/tui/components/chat/editor/history_test.go new file mode 100644 index 000000000..b8654a037 --- /dev/null +++ b/internal/tui/components/chat/editor/history_test.go @@ -0,0 +1,155 @@ +package editor + +import ( + "testing" + + "github.com/charmbracelet/bubbles/v2/textarea" + "github.com/charmbracelet/crush/internal/commandhistory" + "github.com/charmbracelet/crush/internal/session" +) + +func TestCommandHistoryNavigation(t *testing.T) { + // Create test editor with properly initialized textarea + editor := &editorCmp{ + history: []commandhistory.CommandHistory{ + {Command: "first command"}, + {Command: "second command"}, + {Command: "third command"}, + }, + historyIndex: 3, + isInHistoryMode: false, + } + + // Initialize textarea properly + editor.textarea = textarea.New() + editor.textarea.SetValue("current input") + editor.textarea.Focus() + + // Test up arrow navigation - should go to most recent (third command) + cmd := editor.navigateHistory(-1) + if cmd != nil { + t.Error("Expected nil cmd for navigation") + } + + if !editor.isInHistoryMode { + t.Error("Expected to be in history mode") + } + + // The first up should show "third command" (most recent) + if editor.textarea.Value() != "third command" { + t.Errorf("Expected 'third command', got '%s'", editor.textarea.Value()) + } + + // Test another up arrow - should go to second command + cmd = editor.navigateHistory(-1) + if cmd != nil { + t.Error("Expected nil cmd for navigation") + } + + if editor.textarea.Value() != "second command" { + t.Errorf("Expected 'second command', got '%s'", editor.textarea.Value()) + } + + // Test another up arrow - should go to first command + cmd = editor.navigateHistory(-1) + if cmd != nil { + t.Error("Expected nil cmd for navigation") + } + + if editor.textarea.Value() != "first command" { + t.Errorf("Expected 'first command', got '%s'", editor.textarea.Value()) + } + + // Test down arrow - should go to second command + cmd = editor.navigateHistory(1) + if cmd != nil { + t.Error("Expected nil cmd for navigation") + } + + if editor.textarea.Value() != "second command" { + t.Errorf("Expected 'second command', got '%s'", editor.textarea.Value()) + } + + // Test down arrow - should go to third command + cmd = editor.navigateHistory(1) + if cmd != nil { + t.Error("Expected nil cmd for navigation") + } + + if editor.textarea.Value() != "third command" { + t.Errorf("Expected 'third command', got '%s'", editor.textarea.Value()) + } + + // Test down arrow - should exit history mode and return to current input + cmd = editor.navigateHistory(1) + if cmd != nil { + t.Error("Expected nil cmd for navigation") + } + + if editor.isInHistoryMode { + t.Error("Expected to exit history mode") + } + + if editor.textarea.Value() != "current input" { + t.Errorf("Expected 'current input', got '%s'", editor.textarea.Value()) + } +} + +func TestAddToHistory(t *testing.T) { + // Mock app and session for testing + editor := &editorCmp{ + session: session.Session{ID: "test-session"}, + } + + // Test with empty command + cmd := editor.addToHistory("") + if cmd != nil { + // Should return nil for empty command + // In real implementation this would check session.ID and command content + t.Error("Expected nil cmd for empty command") + } + + // Test with valid command + cmd = editor.addToHistory("test command") + if cmd == nil { + t.Error("Expected non-nil cmd for valid command") + } +} + +func TestSetSessionResetsHistory(t *testing.T) { + editor := &editorCmp{ + history: []commandhistory.CommandHistory{ + {Command: "old command"}, + }, + historyIndex: 1, + isInHistoryMode: true, + tempInput: "temp", + } + + newSession := session.Session{ID: "new-session"} + cmd := editor.SetSession(newSession) + + if editor.session.ID != "new-session" { + t.Error("Expected session to be updated") + } + + if len(editor.history) != 0 { + t.Error("Expected history to be reset") + } + + if editor.historyIndex != 0 { + t.Error("Expected history index to be reset") + } + + if editor.isInHistoryMode { + t.Error("Expected history mode to be reset") + } + + if editor.tempInput != "" { + t.Error("Expected temp input to be reset") + } + + if cmd == nil { + t.Error("Expected non-nil cmd to load history") + } +} \ No newline at end of file diff --git a/internal/tui/components/chat/editor/keys.go b/internal/tui/components/chat/editor/keys.go index 8bc8b2354..4998a592c 100644 --- a/internal/tui/components/chat/editor/keys.go +++ b/internal/tui/components/chat/editor/keys.go @@ -9,6 +9,8 @@ type EditorKeyMap struct { SendMessage key.Binding OpenEditor key.Binding Newline key.Binding + HistoryUp key.Binding + HistoryDown key.Binding } func DefaultEditorKeyMap() EditorKeyMap { @@ -32,6 +34,14 @@ func DefaultEditorKeyMap() EditorKeyMap { // to reflect that. key.WithHelp("ctrl+j", "newline"), ), + HistoryUp: key.NewBinding( + key.WithKeys("up"), + key.WithHelp("↑", "previous command"), + ), + HistoryDown: key.NewBinding( + key.WithKeys("down"), + key.WithHelp("↓", "next command"), + ), } } @@ -42,6 +52,8 @@ func (k EditorKeyMap) KeyBindings() []key.Binding { k.SendMessage, k.OpenEditor, k.Newline, + k.HistoryUp, + k.HistoryDown, AttachmentsKeyMaps.AttachmentDeleteMode, AttachmentsKeyMaps.DeleteAllAttachments, AttachmentsKeyMaps.Escape,