Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- **Session-Based:** maintain multiple work sessions and contexts per project
- **LSP-Enhanced:** Crush uses LSPs for additional context, just like you do
- **Extensible:** add capabilities via MCPs (`http`, `stdio`, and `sse`)
- **[Hooks](./internal/hooks/HOOKS.md):** execute custom shell commands at lifecycle events
- **Works Everywhere:** first-class support in every terminal on macOS, Linux, Windows (PowerShell and WSL), FreeBSD, OpenBSD, and NetBSD

## Installation
Expand Down
100 changes: 100 additions & 0 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package agent
import (
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"log/slog"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/hooks"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/session"
Expand Down Expand Up @@ -83,6 +85,7 @@ type sessionAgent struct {
messages message.Service
disableAutoSummarize bool
isYolo bool
hooks *hooks.Executor

messageQueue *csync.Map[string, []SessionAgentCall]
activeRequests *csync.Map[string, context.CancelFunc]
Expand All @@ -98,6 +101,7 @@ type SessionAgentOptions struct {
Sessions session.Service
Messages message.Service
Tools []fantasy.AgentTool
Hooks *hooks.Executor
}

func NewSessionAgent(
Expand All @@ -113,6 +117,7 @@ func NewSessionAgent(
disableAutoSummarize: opts.DisableAutoSummarize,
tools: opts.Tools,
isYolo: opts.IsYolo,
hooks: opts.Hooks,
messageQueue: csync.NewMap[string, []SessionAgentCall](),
activeRequests: csync.NewMap[string, context.CancelFunc](),
}
Expand Down Expand Up @@ -175,6 +180,19 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
return nil, err
}

// Execute UserPromptSubmit hook
if a.hooks != nil {
if err := a.hooks.Execute(ctx, hooks.HookContext{
EventType: config.UserPromptSubmit,
SessionID: call.SessionID,
UserPrompt: call.Prompt,
Provider: a.largeModel.ModelCfg.Provider,
Model: a.largeModel.ModelCfg.Model,
}); err != nil {
slog.Debug("user_prompt_submit hook execution failed", "error", err)
}
}

// Add the session to the context.
ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

Expand Down Expand Up @@ -307,6 +325,25 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
// TODO: implement
},
OnToolCall: func(tc fantasy.ToolCallContent) error {
// Execute PreToolUse hook - blocks tool execution on error
if a.hooks != nil {
toolInput := make(map[string]any)
if err := json.Unmarshal([]byte(tc.Input), &toolInput); err != nil {
slog.Warn("Failed to unmarshal tool input for PreToolUse hook", "error", err, "tool", tc.ToolName)
}
if err := a.hooks.Execute(genCtx, hooks.HookContext{
EventType: config.PreToolUse,
SessionID: call.SessionID,
ToolName: tc.ToolName,
ToolInput: toolInput,
MessageID: currentAssistant.ID,
Provider: a.largeModel.ModelCfg.Provider,
Model: a.largeModel.ModelCfg.Model,
}); err != nil {
return fmt.Errorf("PreToolUse hook blocked tool execution: %w", err)
}
}

toolCall := message.ToolCall{
ID: tc.ToolCallID,
Name: tc.ToolName,
Expand Down Expand Up @@ -335,6 +372,36 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
case fantasy.ToolResultContentTypeMedia:
// TODO: handle this message type
}

// Execute PostToolUse hook
if a.hooks != nil {
toolInput := make(map[string]any)
// Try to get tool input from the assistant message
toolCalls := currentAssistant.ToolCalls()
for _, tc := range toolCalls {
if tc.ID == result.ToolCallID {
if err := json.Unmarshal([]byte(tc.Input), &toolInput); err != nil {
slog.Debug("Failed to unmarshal tool input for PostToolUse hook", "error", err, "tool", result.ToolName)
}
break
}
}

if err := a.hooks.Execute(genCtx, hooks.HookContext{
EventType: config.PostToolUse,
SessionID: call.SessionID,
ToolName: result.ToolName,
ToolInput: toolInput,
ToolResult: resultContent,
ToolError: isError,
MessageID: currentAssistant.ID,
Provider: a.largeModel.ModelCfg.Provider,
Model: a.largeModel.ModelCfg.Model,
}); err != nil {
slog.Debug("post_tool_use hook execution failed", "error", err)
}
}

toolResult := message.ToolResult{
ToolCallID: result.ToolCallID,
Name: result.ToolName,
Expand Down Expand Up @@ -476,6 +543,27 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
}
wg.Wait()

// Execute Stop hook
if a.hooks != nil && result != nil {
var totalTokens, inputTokens int64
for _, step := range result.Steps {
totalTokens += step.Usage.TotalTokens
inputTokens += step.Usage.InputTokens
}

if err := a.hooks.Execute(ctx, hooks.HookContext{
EventType: config.Stop,
SessionID: call.SessionID,
MessageID: currentAssistant.ID,
Provider: a.largeModel.ModelCfg.Provider,
Model: a.largeModel.ModelCfg.Model,
TokensUsed: totalTokens,
TokensInput: inputTokens,
}); err != nil {
slog.Debug("stop hook execution failed", "error", err)
}
}

if shouldSummarize {
a.activeRequests.Del(call.SessionID)
if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
Expand Down Expand Up @@ -525,6 +613,18 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
return nil
}

// Execute PreCompact hook
if a.hooks != nil {
if err := a.hooks.Execute(ctx, hooks.HookContext{
EventType: config.PreCompact,
SessionID: sessionID,
Provider: a.largeModel.ModelCfg.Provider,
Model: a.largeModel.ModelCfg.Model,
}); err != nil {
slog.Debug("pre_compact hook execution failed", "error", err)
}
}

aiMsgs, _ := a.preparePrompt(msgs)

genCtx, cancel := context.WithCancel(ctx)
Expand Down
17 changes: 17 additions & 0 deletions internal/agent/agent_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"

"charm.land/fantasy"

"github.com/charmbracelet/crush/internal/agent/prompt"
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/hooks"
)

//go:embed templates/agent_tool.md
Expand Down Expand Up @@ -104,6 +106,21 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
}

// Execute SubagentStop hook
if c.hooks != nil {
if err := c.hooks.Execute(ctx, hooks.HookContext{
EventType: config.SubagentStop,
SessionID: sessionID,
ToolName: AgentToolName,
MessageID: agentMessageID,
Provider: model.ModelCfg.Provider,
Model: model.ModelCfg.Model,
}); err != nil {
slog.Debug("subagent_stop hook execution failed", "error", err)
}
}

return fantasy.NewTextResponse(result.Response.Content.Text()), nil
}), nil
}
15 changes: 13 additions & 2 deletions internal/agent/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func testEnv(t *testing.T) fakeEnv {
sessions := session.NewService(q)
messages := message.NewService(q)

permissions := permission.NewPermissionService(workingDir, true, []string{})
permissions := permission.NewPermissionService(workingDir, true, []string{}, nil)
history := history.NewService(q, conn)
lspClients := csync.NewMap[string, *lsp.Client]()

Expand Down Expand Up @@ -149,7 +149,18 @@ func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPro
DefaultMaxTokens: 10000,
},
}
agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
agent := NewSessionAgent(SessionAgentOptions{
LargeModel: largeModel,
SmallModel: smallModel,
SystemPromptPrefix: "",
SystemPrompt: systemPrompt,
DisableAutoSummarize: false,
IsYolo: true,
Sessions: env.sessions,
Messages: env.messages,
Tools: tools,
Hooks: nil,
})
return agent
}

Expand Down
23 changes: 14 additions & 9 deletions internal/agent/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/hooks"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
Expand Down Expand Up @@ -61,6 +62,7 @@ type coordinator struct {
permissions permission.Service
history history.Service
lspClients *csync.Map[string, *lsp.Client]
hooks *hooks.Executor

currentAgent SessionAgent
agents map[string]SessionAgent
Expand All @@ -76,6 +78,7 @@ func NewCoordinator(
permissions permission.Service,
history history.Service,
lspClients *csync.Map[string, *lsp.Client],
hooksExecutor *hooks.Executor,
) (Coordinator, error) {
c := &coordinator{
cfg: cfg,
Expand All @@ -84,6 +87,7 @@ func NewCoordinator(
permissions: permissions,
history: history,
lspClients: lspClients,
hooks: hooksExecutor,
agents: make(map[string]SessionAgent),
}

Expand Down Expand Up @@ -287,15 +291,16 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age

largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
result := NewSessionAgent(SessionAgentOptions{
large,
small,
largeProviderCfg.SystemPromptPrefix,
systemPrompt,
c.cfg.Options.DisableAutoSummarize,
c.permissions.SkipRequests(),
c.sessions,
c.messages,
nil,
LargeModel: large,
SmallModel: small,
SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
SystemPrompt: systemPrompt,
DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
IsYolo: c.permissions.SkipRequests(),
Sessions: c.sessions,
Messages: c.messages,
Tools: nil,
Hooks: c.hooks,
})
c.readyWg.Go(func() error {
tools, err := c.buildTools(ctx, agent)
Expand Down
10 changes: 9 additions & 1 deletion internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/format"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/hooks"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
Expand All @@ -47,6 +48,7 @@ type App struct {
LSPClients *csync.Map[string, *lsp.Client]

config *config.Config
hooks *hooks.Executor

serviceEventsWG *sync.WaitGroup
eventsCtx context.Context
Expand All @@ -70,16 +72,20 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
allowedTools = cfg.Permissions.AllowedTools
}

// Initialize hooks executor
hooksExecutor := hooks.NewExecutor(cfg.Hooks, cfg.WorkingDir())

app := &App{
Sessions: sessions,
Messages: messages,
History: files,
Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools, hooksExecutor),
LSPClients: csync.NewMap[string, *lsp.Client](),

globalCtx: ctx,

config: cfg,
hooks: hooksExecutor,

events: make(chan tea.Msg, 100),
serviceEventsWG: &sync.WaitGroup{},
Expand Down Expand Up @@ -313,6 +319,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
if coderAgentCfg.ID == "" {
return fmt.Errorf("coder agent configuration is missing")
}

var err error
app.AgentCoordinator, err = agent.NewCoordinator(
ctx,
Expand All @@ -322,6 +329,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
app.Permissions,
app.History,
app.LSPClients,
app.hooks,
)
if err != nil {
slog.Error("Failed to create coder agent", "err", err)
Expand Down
2 changes: 2 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ type Config struct {

Tools Tools `json:"tools,omitzero" jsonschema:"description=Tool configurations"`

Hooks HookConfig `json:"hooks,omitempty" jsonschema:"description=Hook configurations for lifecycle events"`

Agents map[string]Agent `json:"-"`

// Internal
Expand Down
Loading
Loading