Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CRUSH.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ func TestYourFunction(t *testing.T) {
- You can also use `task fmt` to run `gofumpt -w .` on the entire project,
as long as `gofumpt` is on the `PATH`.

## Comments

- Comments that live on their own lines should start with capital letters and
end with periods. Wrap comments at 78 columns.

## Committing

- ALWAYS use semantic commits (`fix:`, `feat:`, `chore:`, `refactor:`, `docs:`, `sec:`, etc).
- Try to keep commits to one line, not including your attribution. Only use
multi-line commits when additional context is truly necessary.
60 changes: 38 additions & 22 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
// Package agent is the core orchestration layer for Crush AI agents.
//
// It provides session-based AI agent functionality for managing
// conversations, tool execution, and message handling. It coordinates
// interactions between language models, messages, sessions, and tools while
// handling features like automatic summarization, queuing, and token
// management.
package agent

import (
Expand Down Expand Up @@ -131,7 +138,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
}

if len(a.tools) > 0 {
// add anthropic caching to the last tool
// Add Anthropic caching to the last tool.
a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
}

Expand All @@ -153,7 +160,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
}

var wg sync.WaitGroup
// Generate title if first message
// Generate title if first message.
if len(msgs) == 0 {
wg.Go(func() {
sessionLock.Lock()
Expand All @@ -162,13 +169,13 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
})
}

// Add the user message to the session
// Add the user message to the session.
_, err = a.createUserMessage(ctx, call)
if err != nil {
return nil, err
}

// add the session to the context
// Add the session to the context.
ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

genCtx, cancel := context.WithCancel(ctx)
Expand All @@ -195,10 +202,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
PresencePenalty: call.PresencePenalty,
TopK: call.TopK,
FrequencyPenalty: call.FrequencyPenalty,
// Before each step create the new assistant message
// Before each step create a new assistant message.
PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
prepared.Messages = options.Messages
// reset all cached items
// Reset all cached items.
for i := range prepared.Messages {
prepared.Messages[i].ProviderOptions = nil
}
Expand All @@ -216,14 +223,14 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
lastSystemRoleInx := 0
systemMessageUpdated := false
for i, msg := range prepared.Messages {
// only add cache control to the last message
// Only add cache control to the last message.
if msg.Role == fantasy.MessageRoleSystem {
lastSystemRoleInx = i
} else if !systemMessageUpdated {
prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
systemMessageUpdated = true
}
// than add cache control to the last 2 messages
// Than add cache control to the last 2 messages.
if i > len(prepared.Messages)-3 {
prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
}
Expand Down Expand Up @@ -276,6 +283,13 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
return a.messages.Update(genCtx, *currentAssistant)
},
OnTextDelta: func(id string, text string) error {
// Strip leading newline from initial text content. This is is
// particularly important in non-interactive mode where leading
// newlines are very visible.
if len(currentAssistant.Parts) == 0 {
text = strings.TrimPrefix(text, "\n")
}

currentAssistant.AppendContent(text)
return a.messages.Update(genCtx, *currentAssistant)
},
Expand Down Expand Up @@ -387,10 +401,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
if currentAssistant == nil {
return result, err
}
// Ensure we finish thinking on error to close the reasoning state
// Ensure we finish thinking on error to close the reasoning state.
currentAssistant.FinishThinking()
toolCalls := currentAssistant.ToolCalls()
// INFO: we use the parent context here because the genCtx has been cancelled
// INFO: we use the parent context here because the genCtx has been cancelled.
msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
if createErr != nil {
return nil, createErr
Expand Down Expand Up @@ -452,7 +466,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
} else {
currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
}
// INFO: we use the parent context here because the genCtx has been cancelled
// Note: we use the parent context here because the genCtx has been
// cancelled.
updateErr := a.messages.Update(ctx, *currentAssistant)
if updateErr != nil {
return nil, updateErr
Expand All @@ -466,7 +481,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
return nil, summarizeErr
}
// if the agent was not done...
// If the agent wasn't done...
if len(currentAssistant.ToolCalls()) > 0 {
existing, ok := a.messageQueue.Get(call.SessionID)
if !ok {
Expand All @@ -478,15 +493,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
}
}

// release active request before processing queued messages
// Release active request before processing queued messages.
a.activeRequests.Del(call.SessionID)
cancel()

queuedMessages, ok := a.messageQueue.Get(call.SessionID)
if !ok || len(queuedMessages) == 0 {
return result, err
}
// there are queued messages restart the loop
// There are queued messages restart the loop.
firstQueuedMessage := queuedMessages[0]
a.messageQueue.Set(call.SessionID, queuedMessages[1:])
return a.Run(ctx, firstQueuedMessage)
Expand All @@ -506,7 +521,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
return err
}
if len(msgs) == 0 {
// nothing to summarize
// Nothing to summarize.
return nil
}

Expand Down Expand Up @@ -546,7 +561,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
return a.messages.Update(genCtx, summaryMessage)
},
OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
// handle anthropic signature
// Handle anthropic signature.
if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
summaryMessage.AppendReasoningSignature(signature.Signature)
Expand All @@ -563,7 +578,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
if err != nil {
isCancelErr := errors.Is(err, context.Canceled)
if isCancelErr {
// User cancelled summarize we need to remove the summary message
// User cancelled summarize we need to remove the summary message.
deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
return deleteErr
}
Expand All @@ -590,7 +605,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan

a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)

// just in case get just the last usage
// Just in case, get just the last usage info.
usage := resp.Response.Usage
currentSession.SummaryMessageID = summaryMessage.ID
currentSession.CompletionTokens = usage.OutputTokens
Expand Down Expand Up @@ -636,7 +651,8 @@ func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...mess
if len(m.Parts) == 0 {
continue
}
// Assistant message without content or tool calls (cancelled before it returned anything)
// Assistant message without content or tool calls (cancelled before it
// returned anything).
if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
continue
}
Expand Down Expand Up @@ -711,7 +727,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi

title = strings.ReplaceAll(title, "\n", " ")

// remove thinking tags if present
// Remove thinking tags if present.
if idx := strings.Index(title, "</think>"); idx > 0 {
title = title[idx+len("</think>"):]
}
Expand Down Expand Up @@ -777,13 +793,13 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session,
}

func (a *sessionAgent) Cancel(sessionID string) {
// Cancel regular requests
// Cancel regular requests.
if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
slog.Info("Request cancellation initiated", "session_id", sessionID)
cancel()
}

// Also check for summarize requests
// Also check for summarize requests.
if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
slog.Info("Summarize cancellation initiated", "session_id", sessionID)
cancel()
Expand Down
57 changes: 46 additions & 11 deletions internal/app/app.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
// Package app wires together services, coordinates agents, and manages
// application lifecycle.
package app

import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"log/slog"
"os"
"sync"
"time"

Expand All @@ -24,7 +28,11 @@ import (
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/tui/components/anim"
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/x/ansi"
"github.com/charmbracelet/x/exp/charmtone"
)

type App struct {
Expand Down Expand Up @@ -101,17 +109,35 @@ func (app *App) Config() *config.Config {
return app.config
}

// RunNonInteractive handles the execution flow when a prompt is provided via
// CLI flag.
func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error {
// RunNonInteractive runs the application in non-interactive mode with the
// given prompt, printing to stdout.
func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt string, quiet bool) error {
slog.Info("Running in non-interactive mode")

ctx, cancel := context.WithCancel(ctx)
defer cancel()

var spinner *format.Spinner
if !quiet {
spinner = format.NewSpinner(ctx, cancel, "Generating")
t := styles.CurrentTheme()

// Detect background color to set the appropriate color for the
// spinner's 'Generating...' text. Without this, that text would be
// unreadable in light terminals.
hasDarkBG := true
if f, ok := output.(*os.File); ok {
hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, f)
}
defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)

spinner = format.NewSpinner(ctx, cancel, anim.Settings{
Size: 10,
Label: "Generating",
LabelColor: defaultFG,
GradColorA: t.Primary,
GradColorB: t.Secondary,
CycleColors: true,
})
spinner.Start()
}

Expand All @@ -125,7 +151,7 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool
defer stopSpinner()

const maxPromptLengthForTitle = 100
titlePrefix := "Non-interactive: "
const titlePrefix = "Non-interactive: "
var titleSuffix string

if len(prompt) > maxPromptLengthForTitle {
Expand All @@ -141,7 +167,8 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool
}
slog.Info("Created session for non-interactive run", "session_id", sess.ID)

// Automatically approve all permission requests for this non-interactive session
// Automatically approve all permission requests for this non-interactive
// session.
app.Permissions.AutoApproveSession(sess.ID)

type response struct {
Expand All @@ -165,11 +192,19 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool
messageEvents := app.Messages.Subscribe(ctx)
messageReadBytes := make(map[string]int)

defer fmt.Printf(ansi.ResetProgressBar)
defer func() {
_, _ = fmt.Printf(ansi.ResetProgressBar)

// Always print a newline at the end. If output is a TTY this will
// prevent the prompt from overwriting the last line of output.
_, _ = fmt.Fprintln(output)
}()

for {
// HACK: add it again on every iteration so it doesn't get hidden by
// the terminal due to inactivity.
fmt.Printf(ansi.SetIndeterminateProgressBar)
// HACK: Reinitialize the terminal progress bar on every iteration so
// it doesn't get hidden by the terminal due to inactivity.
_, _ = fmt.Printf(ansi.SetIndeterminateProgressBar)
Comment on lines +204 to +206
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this cause the some bugs on iTerm2/SSH that #1329 fixed for interactive?

If so, we might want to address them.

/cc @aymanbagabas

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fwiw this was a present prior to this PR and lives in many places around the codebase.


select {
case result := <-done:
stopSpinner()
Expand All @@ -196,7 +231,7 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool
}

part := content[readBytes:]
fmt.Print(part)
fmt.Fprint(output, part)
messageReadBytes[msg.ID] = len(content)
}

Expand Down
19 changes: 14 additions & 5 deletions internal/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"fmt"
"log/slog"
"os"
"strings"

"github.com/spf13/cobra"
Expand All @@ -18,10 +19,13 @@ The prompt can be provided as arguments or piped from stdin.`,
crush run Explain the use of context in Go

# Pipe input from stdin
echo "What is this code doing?" | crush run
curl https://charm.land | crush run "Summarize this website"

# Run with quiet mode (no spinner)
crush run -q "Generate a README for this project"
# Read from a file
crush run "What is this code doing?" <<< prrr.go

# Run in quiet mode (hide the spinner)
crush run --quiet "Generate a README for this project"
`,
RunE: func(cmd *cobra.Command, args []string) error {
quiet, _ := cmd.Flags().GetBool("quiet")
Expand All @@ -48,8 +52,13 @@ crush run -q "Generate a README for this project"
return fmt.Errorf("no prompt provided")
}

// Run non-interactive flow using the App method
return app.RunNonInteractive(cmd.Context(), prompt, quiet)
// TODO: Make this work when redirected to something other than stdout.
// For example:
// crush run "Do something fancy" > output.txt
// echo "Do something fancy" | crush run > output.txt
//
// TODO: We currently need to press ^c twice to cancel. Fix that.
return app.RunNonInteractive(cmd.Context(), os.Stdout, prompt, quiet)
},
}

Expand Down
Loading
Loading