diff --git a/go.mod b/go.mod index 27ca8d5db..2e33df24c 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,8 @@ require ( mvdan.cc/sh/v3 v3.12.1-0.20250902163504-3cf4fd5717a5 ) +require github.com/coder/acp-go-sdk v0.4.9 // indirect + require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.13.0 // indirect diff --git a/go.sum b/go.sum index f390f506f..341f40c2e 100644 --- a/go.sum +++ b/go.sum @@ -114,6 +114,8 @@ github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8 github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM= github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k= +github.com/coder/acp-go-sdk v0.4.9 h1:F4sKT2up4sMqNYt6yt2L9g4MaE09VPgt3eRqDFnoY5k= +github.com/coder/acp-go-sdk v0.4.9/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= diff --git a/internal/acp/agent.go b/internal/acp/agent.go new file mode 100644 index 000000000..92d21ee3b --- /dev/null +++ b/internal/acp/agent.go @@ -0,0 +1,369 @@ +package acp + +import ( + "context" + "errors" + "fmt" + "github.com/charmbracelet/crush/internal/acp/terminal" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/cwd" + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/llm/agent" + "github.com/charmbracelet/crush/internal/permission" + "github.com/coder/acp-go-sdk" + "log/slog" + "strings" + "time" +) + +type Agent struct { + app *app.App + conn *acp.AgentSideConnection + terminals *terminal.Service + sink *agentEventSink + promptDone chan any + client acp.ClientCapabilities + debug bool + yolo bool + dataDir string +} + +var ( + _ acp.Agent = (*Agent)(nil) + _ acp.AgentLoader = (*Agent)(nil) + _ acp.AgentExperimental = (*Agent)(nil) +) + +func NewAgent(debug bool, yolo bool, dataDir string) (*Agent, error) { + return &Agent{ + debug: debug, + yolo: yolo, + dataDir: dataDir, + }, nil +} + +func (a *Agent) SetSessionMode(ctx context.Context, params acp.SetSessionModeRequest) (acp.SetSessionModeResponse, error) { + slog.Info("SetSessionMode") + return acp.SetSessionModeResponse{}, nil +} + +func (a *Agent) SetSessionModel(ctx context.Context, params acp.SetSessionModelRequest) (acp.SetSessionModelResponse, error) { + slog.Info("SetSessionModel") + return acp.SetSessionModelResponse{}, nil +} + +func (a *Agent) SetAgentConnection(conn *acp.AgentSideConnection) { a.conn = conn } + +func (a *Agent) Initialize(ctx context.Context, params acp.InitializeRequest) (acp.InitializeResponse, error) { + slog.Debug("Initialize", "params", params) + a.client = params.ClientCapabilities + a.terminals = terminal.NewService(a.conn, a.client.Terminal) + + return acp.InitializeResponse{ + ProtocolVersion: acp.ProtocolVersionNumber, + AgentCapabilities: acp.AgentCapabilities{ + LoadSession: false, + McpCapabilities: acp.McpCapabilities{ + Http: false, + Sse: false, + }, + PromptCapabilities: acp.PromptCapabilities{ + EmbeddedContext: true, + Audio: false, + Image: false, + }, + }, + }, nil +} + +func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) { + slog.Info("New session requested...") + appInstance, err := a.setupApp(ctx, params) + if err != nil { + return acp.NewSessionResponse{}, err + } + a.app = appInstance + a.sink = newAgentSink(ctx, a) + a.promptDone = make(chan any) + close(a.promptDone) // first prompt may run straight away + + go app.Subscribe[any](appInstance, a.sink) + + s, err := a.app.Sessions.Create(ctx, "New ACP Session") + if err != nil { + return acp.NewSessionResponse{}, err + } + + // TODO: send models/modes + //models := a.app.Config().Models + resp := acp.NewSessionResponse{ + Models: nil, + Modes: nil, + SessionId: acp.SessionId(s.ID), + } + + go func() { + _ = a.NotifySlashCommands(ctx, resp.SessionId, defaultSlashCommands) + }() + + // E.g. we can read remote file like this + //go func() { + // r, _ := a.ReadTextFile(ctx, resp.SessionId, "/Users/andrei/Projects/cache-decorator/src/cache_decorator/storages/memory.py", 0, 0) + //}() + + // E.g. we can write remote file like this + //go func() { + // _ = a.WriteTextFile(ctx, resp.SessionId, "/Users/andrei/Projects/cache-decorator/src/cache_decorator/storages/memory1.py", "Hello here") + //}() + + // E.g. we can call terminal command on client side like this + //go func() { + // if t, err := a.terminals.Create(ctx, resp.SessionId, "ls", terminal.WithArgs("-la")); err == nil { + // _ = t.EmbedInToolCalls(ctx, a.conn) + // } + // + //}() + + return resp, nil +} + +func (a *Agent) NotifySlashCommands(ctx context.Context, sessionId acp.SessionId, commands SlashCommandRegistry) error { + notifyCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if err := a.conn.SessionUpdate(notifyCtx, acp.SessionNotification{ + SessionId: sessionId, + Update: acp.SessionUpdate{ + AvailableCommandsUpdate: &acp.SessionUpdateAvailableCommandsUpdate{ + AvailableCommands: AvailableCommands(commands), + }, + }, + }); err != nil { + slog.Error("failed to send available-commands update", "error", err) + return err + } + + return nil +} + +func (a *Agent) Authenticate(ctx context.Context, _ acp.AuthenticateRequest) (acp.AuthenticateResponse, error) { + slog.Info("Authenticate") + return acp.AuthenticateResponse{}, nil +} + +func (a *Agent) LoadSession(ctx context.Context, _ acp.LoadSessionRequest) (acp.LoadSessionResponse, error) { + slog.Info("LoadSession") + return acp.LoadSessionResponse{}, nil +} + +func (a *Agent) Cancel(ctx context.Context, params acp.CancelNotification) error { + slog.Info("Cancel") + _, err := a.app.Sessions.Get(ctx, string(params.SessionId)) + if err != nil { + return err + } + + if a.app.CoderAgent != nil { + a.app.CoderAgent.Cancel(string(params.SessionId)) + } + + return nil +} + +func (a *Agent) RunPrompt(ctx context.Context, prompt string, params acp.PromptRequest) error { + sid := string(params.SessionId) + if a.app.CoderAgent.IsSessionBusy(sid) { + slog.Info("Cancel previous prompt.") + a.app.CoderAgent.Cancel(sid) + <-a.promptDone // wait until previous turn canceled + } + + slog.Info("Process a new prompt.") + done, err := a.app.CoderAgent.Run(ctx, string(params.SessionId), prompt) + if err != nil { + slog.Error("Cant run coder agent", "err", err) + return err + } + + a.promptDone = make(chan any) + defer close(a.promptDone) + for { + select { + case result := <-done: + // nil, context.Canceled, or agent.ErrRequestCancelled + return result.Error + } + } +} + +func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.PromptResponse, error) { + var err error + + sid := string(params.SessionId) + if _, err = a.app.Sessions.Get(ctx, sid); err != nil { + err = fmt.Errorf("session %s not found", params.SessionId) + } else { + prompt := Prompt(params.Prompt).String() + a.sink.LastUserPrompt(prompt) + + // FIXME: Add support for different types of content (image, audio and etc) + name, text := parseSlash(prompt) + if name != "" { // slash-command + if cmd := defaultSlashCommands.Get(name); cmd != nil { + slog.Info("Slash command requested", "cmd", name) + err = cmd.Exec(ctx, a, text, params) + } + } else { // normal LLM turn + err = a.RunPrompt(ctx, prompt, params) + } + } + + switch { + case err == nil: + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil + case errors.Is(err, context.Canceled), errors.Is(err, agent.ErrRequestCancelled): + return acp.PromptResponse{StopReason: acp.StopReasonCancelled}, nil + default: + return acp.PromptResponse{}, err // real failure + } +} + +// ReadTextFile allows Agents to read text file contents from the Client’s filesystem, including unsaved changes in the editor. +func (a *Agent) ReadTextFile(ctx context.Context, sessionId acp.SessionId, path string, line int, limit int) (string, error) { + if !a.client.Fs.ReadTextFile { + return "", errors.New("client does not support reading of text files") + } + + var pLine, pLimit *int + if line > 0 { + pLine = acp.Ptr(line) + } + + if limit > 0 { + pLimit = acp.Ptr(limit) + } + + if resp, err := a.conn.ReadTextFile(ctx, acp.ReadTextFileRequest{ + SessionId: sessionId, + Path: path, + Line: pLine, + Limit: pLimit, + }); err != nil { + slog.Error("could not read remote file", "error", err) + return "", err + } else { + return resp.Content, nil + } +} + +// WriteTextFile allows Agents to write or update text files in the Client’s filesystem. +func (a *Agent) WriteTextFile(ctx context.Context, sessionId acp.SessionId, path string, content string) error { + if !a.client.Fs.WriteTextFile { + return errors.New("client does not support writing of text files") + } + + if _, err := a.conn.WriteTextFile(ctx, acp.WriteTextFileRequest{ + SessionId: sessionId, + Path: path, + Content: content, + }); err != nil { + slog.Error("could not write to remote file", "error", err) + return err + } + + return nil +} + +func (a *Agent) setupApp(ctx context.Context, params acp.NewSessionRequest) (*app.App, error) { + cwDir, err := cwd.Resolve(params.Cwd) + if err != nil { + return nil, err + } + + cfg, err := config.Init(cwDir, a.dataDir, a.debug) + if err != nil { + return nil, err + } + + if cfg.Permissions == nil { + cfg.Permissions = &config.Permissions{} + } + cfg.Permissions.SkipRequests = a.yolo + + if err := cwd.CreateDotCrushDir(cfg.Options.DataDirectory); err != nil { + return nil, err + } + + // Connect to DB; this will also run migrations. + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + return nil, err + } + + appInstance, err := app.New(ctx, conn, cfg) + if err != nil { + slog.Error("Failed to create app instance", "error", err) + return nil, err + } + + return appInstance, nil +} + +func (a *Agent) RequestPermission(ctx context.Context, req permission.PermissionRequest) { + slog.Info("RequestPermission", "req", req) + payload := acp.RequestPermissionRequest{ + SessionId: acp.SessionId(req.SessionID), + ToolCall: acp.ToolCallUpdate{ + ToolCallId: acp.ToolCallId(req.ToolCallID), + Title: acp.Ptr(req.Description), + Kind: acp.Ptr(acp.ToolKindEdit), + Status: acp.Ptr(acp.ToolCallStatusPending), + Locations: []acp.ToolCallLocation{{Path: req.Path}}, + RawInput: req.Params, + }, Options: []acp.PermissionOption{ + {Kind: acp.PermissionOptionKindAllowOnce, Name: "Allow this change", OptionId: acp.PermissionOptionId("allow")}, + {Kind: acp.PermissionOptionKindRejectOnce, Name: "Skip this change", OptionId: acp.PermissionOptionId("reject")}, + }} + + result, err := a.conn.RequestPermission(ctx, payload) + if err != nil { + slog.Error("error sending permission request", err) + return + } + + if result.Outcome.Selected != nil { + a.app.Permissions.Grant(req) + } else { + a.app.Permissions.Deny(req) + } +} + +// parseSlash parses "input" and returns: +// +// ("", input) – not a slash command +// ("cmd", "rest") – "/cmd rest" +func parseSlash(input string) (cmd, rest string) { + input = strings.TrimSpace(input) + if input == "" || input[0] != '/' { + return "", input + } + after := input[1:] + if i := strings.IndexByte(after, ' '); i == -1 { + return after, "" + } else { + return after[:i], strings.TrimSpace(after[i:]) + } +} + +type Prompt []acp.ContentBlock + +func (p Prompt) String() string { + var sb strings.Builder + for _, b := range p { + if b.Text != nil { + sb.WriteString(b.Text.Text) + } + } + return sb.String() +} diff --git a/internal/acp/server.go b/internal/acp/server.go new file mode 100644 index 000000000..d814881ed --- /dev/null +++ b/internal/acp/server.go @@ -0,0 +1,79 @@ +package acp + +import ( + "context" + "fmt" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/log" + "github.com/coder/acp-go-sdk" + "log/slog" + "os" + "os/signal" + "path/filepath" + "sync" + "syscall" +) + +type Server struct { + ctx context.Context + cancel context.CancelFunc + + //TODO: Only stdio as transport is part of standard, http is still a draft, so only one agent until that + agent *Agent + debug bool + yolo bool + dataDir string +} + +func NewServer(ctx context.Context, debug bool, yolo bool, dataDir string) (*Server, error) { + ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, os.Kill, syscall.SIGTERM) + log.Setup( + filepath.Join(LogsDir(), "logs", fmt.Sprintf("%s.log", config.AppName)), + debug, + ) + + return &Server{ + ctx: ctx, + cancel: cancel, + debug: debug, + yolo: yolo, + dataDir: dataDir, + }, nil +} + +func (s *Server) Run() error { + agent, err := NewAgent(s.debug, s.yolo, s.dataDir) + if err != nil { + return err + } + s.agent = agent + slog.Info("Running in ACP mode") + + conn := acp.NewAgentSideConnection(agent, os.Stdout, os.Stdin) + agent.SetAgentConnection(conn) + conn.SetLogger(slog.Default()) + + select { + case <-conn.Done(): + slog.Debug("peer disconnected, shutting down") + case <-s.ctx.Done(): + slog.Debug("received termination signal, shutting down", "signal", s.ctx.Err()) + } + + return nil +} + +func (s *Server) Shutdown() { + // Graceful shutdown + s.agent = nil +} + +var LogsDir = sync.OnceValue(func() string { + tmp := filepath.Join(os.TempDir(), config.AppName) + if err := os.MkdirAll(tmp, 0755); err != nil { + slog.Error("could not create temp dir", "tmp", tmp) + os.Exit(-1) + } + + return tmp +}) diff --git a/internal/acp/sink.go b/internal/acp/sink.go new file mode 100644 index 000000000..0be47771e --- /dev/null +++ b/internal/acp/sink.go @@ -0,0 +1,71 @@ +package acp + +import ( + "context" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/coder/acp-go-sdk" + "log/slog" +) + +// implementing App's EventSink +type agentEventSink struct { + ctx context.Context + agent *Agent + updatesIter *updateIterator + lastUserPrompt string +} + +var ( + _ app.EventSink[any] = (*agentEventSink)(nil) +) + +func newAgentSink(ctx context.Context, agent *Agent) *agentEventSink { + return &agentEventSink{ + ctx: ctx, + agent: agent, + updatesIter: newUpdatesIterator(), + } +} + +func (sink *agentEventSink) Send(msg any) { + switch ev := msg.(type) { + case pubsub.Event[message.Message]: + sink.handleMessage(ev.Payload) + case pubsub.Event[permission.PermissionRequest]: + sink.handlePermission(ev.Payload) + } +} + +func (sink *agentEventSink) handleToolCall(t message.ToolCall) { + slog.Info("handleToolCall", "tool", t) +} + +func (sink *agentEventSink) handleMessage(m message.Message) { + if m.Role == message.User && sink.lastUserPrompt == m.Content().String() { + return + } + + for update := range sink.updatesIter.next(&m) { + if err := sink.agent.conn.SessionUpdate(sink.ctx, acp.SessionNotification{ + SessionId: acp.SessionId(m.SessionID), + Update: update, + }); err != nil { + slog.Error("session update failed", "error", err) + } + } +} + +func (sink *agentEventSink) handlePermission(req permission.PermissionRequest) { + sink.agent.RequestPermission(sink.ctx, req) +} + +func (sink *agentEventSink) LastUserPrompt(prompt string) { + sink.lastUserPrompt = prompt +} + +func (sink *agentEventSink) Quit() { + // +} diff --git a/internal/acp/slash.go b/internal/acp/slash.go new file mode 100644 index 000000000..fcb9f4bb7 --- /dev/null +++ b/internal/acp/slash.go @@ -0,0 +1,77 @@ +package acp + +import ( + "context" + "fmt" + "github.com/charmbracelet/crush/internal/llm/prompt" + "github.com/coder/acp-go-sdk" +) + +type SlashCommand interface { + Name() string + Help() string + Exec(ctx context.Context, agent *Agent, text string, params acp.PromptRequest) error +} + +type SlashCommandRegistry []SlashCommand + +func (r SlashCommandRegistry) Get(name string) SlashCommand { + for _, cmd := range defaultSlashCommands { + if cmd.Name() == name { + return cmd + } + } + + return nil +} + +// AvailableCommands generates a slice of acp.AvailableCommand from a slice of SlashCommand +func AvailableCommands(commands SlashCommandRegistry) []acp.AvailableCommand { + out := make([]acp.AvailableCommand, 0, len(commands)) + for _, cmd := range commands { + out = append(out, acp.AvailableCommand{ + Name: cmd.Name(), + Input: &acp.AvailableCommandInput{ + &acp.UnstructuredCommandInput{ + Hint: cmd.Help(), + }, + }, + }) + } + return out +} + +var defaultSlashCommands = SlashCommandRegistry{ + yoloCmd{}, + initCmd{}, +} + +type yoloCmd struct{} + +func (yoloCmd) Name() string { return "yolo" } +func (yoloCmd) Help() string { return "Toggle Yolo Mode" } +func (yoloCmd) Exec(ctx context.Context, agent *Agent, text string, params acp.PromptRequest) error { + agent.app.Permissions.SetSkipRequests(!agent.app.Permissions.SkipRequests()) + status := agent.app.Permissions.SkipRequests() + + return agent.conn.SessionUpdate(ctx, acp.SessionNotification{ + SessionId: params.SessionId, + Update: acp.UpdateAgentMessage(acp.ContentBlock{ + Text: &acp.ContentBlockText{ + Text: fmt.Sprintf("YOLO mode is now **%s**.", map[bool]string{ + true: "ON", + false: "OFF", + }[status]), + Type: "text", + }, + }), + }) +} + +type initCmd struct{} + +func (initCmd) Name() string { return "init" } +func (initCmd) Help() string { return "Initialize Project" } +func (initCmd) Exec(ctx context.Context, agent *Agent, text string, params acp.PromptRequest) error { + return agent.RunPrompt(ctx, prompt.Initialize(), params) +} diff --git a/internal/acp/terminal/options.go b/internal/acp/terminal/options.go new file mode 100644 index 000000000..15a6ec843 --- /dev/null +++ b/internal/acp/terminal/options.go @@ -0,0 +1,27 @@ +package terminal + +type createOpts struct { + args []string + env map[string]string + cwd string + byteLim *int +} + +// CreateOption defines a single optional argument for Create. +type CreateOption func(*createOpts) + +func WithArgs(a ...string) CreateOption { + return func(o *createOpts) { o.args = a } +} + +func WithEnv(e map[string]string) CreateOption { + return func(o *createOpts) { o.env = e } +} + +func WithCwd(dir string) CreateOption { + return func(o *createOpts) { o.cwd = dir } +} + +func WithByteLimit(n int) CreateOption { + return func(o *createOpts) { o.byteLim = &n } +} diff --git a/internal/acp/terminal/service.go b/internal/acp/terminal/service.go new file mode 100644 index 000000000..fe41b58f3 --- /dev/null +++ b/internal/acp/terminal/service.go @@ -0,0 +1,104 @@ +package terminal + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/coder/acp-go-sdk" +) + +// Service owns 0…N running terminals for one session. +type Service struct { + conn *acp.AgentSideConnection + termAllowed bool + mu sync.RWMutex + term map[acp.SessionId]map[ID]*Terminal +} + +func NewService(conn *acp.AgentSideConnection, termAllowed bool) *Service { + return &Service{ + conn: conn, + termAllowed: termAllowed, + term: make(map[acp.SessionId]map[ID]*Terminal), + } +} + +// Create launches a new terminal and keeps it in the registry. +func (s *Service) Create(ctx context.Context, sessionID acp.SessionId, cmd string, opts ...CreateOption) (*Terminal, error) { + if !s.termAllowed { + return nil, errors.New("client does not support terminal capability") + } + + // apply defaults + co := &createOpts{} + for _, fn := range opts { + fn(co) + } + + // build env slice + env := make([]acp.EnvVariable, 0, len(co.env)) + for k, v := range co.env { + env = append(env, acp.EnvVariable{Name: k, Value: v}) + } + + t := New(cmd, co.args, env, co.cwd, co.byteLim) + if err := t.Start(ctx, s.conn, sessionID); err != nil { + return nil, err + } + + // register + s.mu.Lock() + if s.term[sessionID] == nil { + s.term[sessionID] = make(map[ID]*Terminal) + } + s.term[sessionID][t.ID] = t + s.mu.Unlock() + return t, nil +} + +// Get returns an existing terminal or error. +func (s *Service) Get(sessionID acp.SessionId, id ID) (*Terminal, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + set, ok := s.term[sessionID] + if !ok { + return nil, fmt.Errorf("session %q has no terminals", sessionID) + } + t, ok := set[id] + if !ok { + return nil, fmt.Errorf("terminal %q not found", id) + } + return t, nil +} + +// Release removes **one** terminal and calls its Release method. +func (s *Service) Release(ctx context.Context, sessionID acp.SessionId, id ID) error { + s.mu.Lock() + t, ok := s.term[sessionID][id] + if !ok { + s.mu.Unlock() + return fmt.Errorf("terminal %q not found", id) + } + delete(s.term[sessionID], id) + if len(s.term[sessionID]) == 0 { + delete(s.term, sessionID) + } + s.mu.Unlock() + + return t.Release(ctx, s.conn) +} + +// ReleaseAll kills every terminal that belongs to a session (handy on session close). +func (s *Service) ReleaseAll(ctx context.Context, sessionID acp.SessionId) { + s.mu.Lock() + set := s.term[sessionID] + delete(s.term, sessionID) + s.mu.Unlock() + + for _, t := range set { + _ = t.Release(ctx, s.conn) + } +} diff --git a/internal/acp/terminal/terminal.go b/internal/acp/terminal/terminal.go new file mode 100644 index 000000000..c06d82545 --- /dev/null +++ b/internal/acp/terminal/terminal.go @@ -0,0 +1,108 @@ +package terminal + +import ( + "context" + "github.com/coder/acp-go-sdk" + "github.com/google/uuid" + "log/slog" +) + +type ID string + +// Terminal wraps one ACP terminal plus local metadata. +type Terminal struct { + ID ID + SessionID acp.SessionId + Cmd string + Args []string + Env []acp.EnvVariable + Cwd string + ByteLim *int +} + +// New creates a Terminal value, but does NOT start it. +func New(cmd string, args []string, env []acp.EnvVariable, cwd string, byteLim *int) *Terminal { + return &Terminal{ + Cmd: cmd, + Args: args, + Env: env, + Cwd: cwd, + ByteLim: byteLim, + } +} + +// Start starts a command in a new terminal +func (t *Terminal) Start(ctx context.Context, conn *acp.AgentSideConnection, sessionID acp.SessionId) error { + req := acp.CreateTerminalRequest{ + SessionId: sessionID, + Command: t.Cmd, + Args: t.Args, + Env: t.Env, + OutputByteLimit: t.ByteLim, + } + if t.Cwd != "" { + req.Cwd = &t.Cwd + } + + resp, err := conn.CreateTerminal(ctx, req) + if err != nil { + return err + } + t.ID = ID(resp.TerminalId) + t.SessionID = sessionID + return nil +} + +// Output retrieves the current terminal output without waiting for the command to complete +func (t *Terminal) Output(ctx context.Context, conn *acp.AgentSideConnection) (acp.TerminalOutputResponse, error) { + return conn.TerminalOutput(ctx, acp.TerminalOutputRequest{ + SessionId: t.SessionID, + TerminalId: string(t.ID), + }) +} + +// WaitForExit returns once the command completes +func (t *Terminal) WaitForExit(ctx context.Context, conn *acp.AgentSideConnection) (acp.WaitForTerminalExitResponse, error) { + return conn.WaitForTerminalExit(ctx, acp.WaitForTerminalExitRequest{ + SessionId: t.SessionID, + TerminalId: string(t.ID), + }) +} + +// Kill terminates a command without releasing the terminal +func (t *Terminal) Kill(ctx context.Context, conn *acp.AgentSideConnection) (acp.KillTerminalCommandResponse, error) { + return conn.KillTerminalCommand(ctx, acp.KillTerminalCommandRequest{ + SessionId: t.SessionID, + TerminalId: string(t.ID), + }) +} + +// Release kills the command if still running and releases all resources +func (t *Terminal) Release(ctx context.Context, conn *acp.AgentSideConnection) error { + _, err := conn.ReleaseTerminal(ctx, acp.ReleaseTerminalRequest{ + SessionId: t.SessionID, + TerminalId: string(t.ID), + }) + if err != nil { + slog.Error("could not release terminal", "err", err) + } + + return err +} + +// EmbedInToolCalls produces the SessionUpdate that advertises the terminal via tool calls +func (t *Terminal) EmbedInToolCalls(ctx context.Context, conn *acp.AgentSideConnection) error { + return conn.SessionUpdate(ctx, acp.SessionNotification{ + SessionId: t.SessionID, + Update: acp.SessionUpdate{ + ToolCall: &acp.SessionUpdateToolCall{ + ToolCallId: acp.ToolCallId("terminal_call_" + uuid.New().String()), + Kind: acp.ToolKindExecute, + Status: acp.ToolCallStatusInProgress, + Content: []acp.ToolCallContent{ + acp.ToolTerminalRef(string(t.ID)), + }, + }, + }, + }) +} diff --git a/internal/acp/tools.go b/internal/acp/tools.go new file mode 100644 index 000000000..77f9b7516 --- /dev/null +++ b/internal/acp/tools.go @@ -0,0 +1,115 @@ +package acp + +import ( + "encoding/json" + "fmt" + "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/charmbracelet/crush/internal/message" + "github.com/coder/acp-go-sdk" + "log/slog" + "strings" +) + +type ToolCall message.ToolCall + +func (t ToolCall) Kind() acp.ToolKind { + switch t.Name { + case tools.BashToolName, tools.BashNoOutput: + return acp.ToolKindExecute + case tools.DownloadToolName, tools.FetchToolName: + return acp.ToolKindFetch + case tools.GlobToolName, tools.LSToolName, tools.GrepToolName: + return acp.ToolKindSearch + case tools.EditToolName, tools.MultiEditToolName, tools.WriteToolName: + return acp.ToolKindEdit + case tools.ViewToolName: + return acp.ToolKindRead + } + + return acp.ToolKindOther +} + +func (t ToolCall) StartToolCall() *acp.SessionUpdateToolCall { + result := &acp.SessionUpdateToolCall{ + ToolCallId: acp.ToolCallId(t.ID), + Kind: t.Kind(), + Title: t.Name, + Status: acp.ToolCallStatusPending, + } + + return result +} + +func (t ToolCall) UpdateToolCall() *acp.SessionUpdateToolCallUpdate { + input := map[string]any{} + if err := json.Unmarshal([]byte(t.Input), &input); err != nil { + slog.Warn("Error decoding input data", "err", err) + } + + result := &acp.SessionUpdateToolCallUpdate{ + ToolCallId: acp.ToolCallId(t.ID), + Status: acp.Ptr(acp.ToolCallStatusInProgress), + } + + filePath, _ := input["file_path"].(string) + offset, _ := input["offset"].(int) + limit, _ := input["limit"].(int) + oldText, _ := input["old_string"].(string) + newText, _ := input["new_string"].(string) + content, _ := input["content"].(string) + + var locations []acp.ToolCallLocation + if filePath != "" { + locations = append(locations, acp.ToolCallLocation{ + Path: filePath, + Line: acp.Ptr(offset), + }) + } + + switch t.Name { + case tools.EditToolName: + { + var title strings.Builder + title.WriteString("Edit ") + title.WriteString(filePath) + result.Title = acp.Ptr(title.String()) + result.Content = []acp.ToolCallContent{acp.ToolDiffContent(filePath, newText, oldText)} + } + case tools.WriteToolName: + { + var title strings.Builder + title.WriteString("Edit ") + title.WriteString(filePath) + result.Title = acp.Ptr(title.String()) + + if filePath != "" { + result.Content = []acp.ToolCallContent{acp.ToolDiffContent(filePath, newText, "")} + } else { + result.Content = []acp.ToolCallContent{ + acp.ToolContent(acp.ContentBlock{ + Text: acp.Ptr(acp.ContentBlockText{Text: content}), + }), + } + } + } + case tools.ViewToolName: + { + var title strings.Builder + title.WriteString("Read ") + title.WriteString(filePath) + switch { + case limit > 0: + fmt.Fprintf(&title, " (%d - %d)", offset, offset+limit) + case offset > 0: + fmt.Fprintf(&title, " (from line %d)", offset) + default: + title.WriteString(" File") + } + + result.Title = acp.Ptr(title.String()) + result.Locations = locations + } + } + + return result +} diff --git a/internal/acp/updates.go b/internal/acp/updates.go new file mode 100644 index 000000000..4644e8e6b --- /dev/null +++ b/internal/acp/updates.go @@ -0,0 +1,136 @@ +package acp + +import ( + "github.com/charmbracelet/crush/internal/message" + "github.com/coder/acp-go-sdk" + "iter" + "log/slog" +) + +type updateIterator struct { + lastT map[message.MessageRole]int + lastR int +} + +func newUpdatesIterator() *updateIterator { + i := &updateIterator{} + i.reset() + return i +} + +func (it *updateIterator) next(msg *message.Message) iter.Seq[acp.SessionUpdate] { + return func(yield func(acp.SessionUpdate) bool) { + for _, p := range msg.Parts { + if n := it.getUpdate(msg.Role, p); n != (acp.SessionUpdate{}) && !yield(n) { + return + } + } + } +} + +// FIXME: Add support for different types of content (image, audio and etc) +func (it *updateIterator) getContentBlock(role message.MessageRole, part message.ContentPart) (result acp.ContentBlock) { + switch v := part.(type) { + case message.TextContent: + lastLen := it.lastT[role] + nextLen := len(v.Text) + if nextLen <= lastLen { + return + } + delta := v.Text[lastLen:] + it.lastT[role] = nextLen + if delta != "" { + return acp.ContentBlock{ + Text: &acp.ContentBlockText{ + Text: delta, + }, + } + } + + case message.ReasoningContent: + if len(v.Thinking) <= it.lastR { + return + } + + delta := v.Thinking[it.lastR:] + it.lastR = len(v.Thinking) + if delta != "" { + return acp.ContentBlock{ + Text: &acp.ContentBlockText{ + Text: delta, + }, + } + } + + case message.BinaryContent: + case message.ImageURLContent: + case message.Finish: + it.reset() + } + + return +} + +func (it *updateIterator) reset() { + it.lastT = make(map[message.MessageRole]int) + it.lastR = 0 +} + +func (it *updateIterator) getUpdate(role message.MessageRole, part message.ContentPart) (result acp.SessionUpdate) { + content := it.getContentBlock(role, part) + hasContent := content != (acp.ContentBlock{}) + + switch t := part.(type) { + case message.ToolCall: + { + slog.Info("ToolCall", "t", t) + tool := ToolCall(t) + if !t.Finished { + return acp.SessionUpdate{ToolCall: tool.StartToolCall()} + } + return acp.SessionUpdate{ToolCallUpdate: tool.UpdateToolCall()} + } + case message.ToolResult: + { + slog.Info("ToolResult") + status := acp.ToolCallStatusCompleted + if t.IsError { + status = acp.ToolCallStatusFailed + } + + // FIXME: refactor it in the same way as ToolCall + // TODO: add support for images? + return acp.UpdateToolCall( + acp.ToolCallId(t.ToolCallID), + acp.WithUpdateStatus(status), + acp.WithUpdateContent([]acp.ToolCallContent{ + acp.ToolContent(acp.ContentBlock{ + Text: &acp.ContentBlockText{ + Text: t.Content, + Meta: t.Metadata, + }, + }), + }), + ) + } + case message.ReasoningContent: + if hasContent { + return acp.UpdateAgentThought(content) + } + default: + { + switch role { + case message.Assistant: + if hasContent { + return acp.UpdateAgentMessage(content) + } + case message.User: + if hasContent { + return acp.UpdateUserMessage(content) + } + } + } + } + + return +} diff --git a/internal/app/app.go b/internal/app/app.go index 8f305f765..79259441a 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,7 +9,6 @@ import ( "sync" "time" - tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" @@ -25,6 +24,11 @@ import ( "github.com/charmbracelet/x/ansi" ) +type EventSink[T any] interface { + Send(e T) + Quit() +} + type App struct { Sessions session.Service Messages message.Service @@ -39,8 +43,8 @@ type App struct { serviceEventsWG *sync.WaitGroup eventsCtx context.Context - events chan tea.Msg - tuiWG *sync.WaitGroup + events chan any + consumerWg *sync.WaitGroup // global context and cleanup functions globalCtx context.Context @@ -70,9 +74,9 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { config: cfg, - events: make(chan tea.Msg, 100), + events: make(chan any, 100), serviceEventsWG: &sync.WaitGroup{}, - tuiWG: &sync.WaitGroup{}, + consumerWg: &sync.WaitGroup{}, } app.setupEvents() @@ -232,7 +236,7 @@ func setupSubscriber[T any]( wg *sync.WaitGroup, name string, subscriber func(context.Context) <-chan pubsub.Event[T], - outputCh chan<- tea.Msg, + outputCh chan<- any, ) { wg.Go(func() { subCh := subscriber(ctx) @@ -243,9 +247,8 @@ func setupSubscriber[T any]( slog.Debug("subscription channel closed", "name", name) return } - var msg tea.Msg = event select { - case outputCh <- msg: + case outputCh <- event: case <-time.After(2 * time.Second): slog.Warn("message dropped due to slow consumer", "name", name) case <-ctx.Done(): @@ -287,34 +290,34 @@ func (app *App) InitCoderAgent() error { return nil } -// Subscribe sends events to the TUI as tea.Msgs. -func (app *App) Subscribe(program *tea.Program) { +// Subscribe sends events to the EventSink as M +func Subscribe[M any](app *App, target EventSink[M]) { defer log.RecoverPanic("app.Subscribe", func() { - slog.Info("TUI subscription panic: attempting graceful shutdown") - program.Quit() + slog.Info("Consumer subscription panic: attempting graceful shutdown") + target.Quit() }) - app.tuiWG.Add(1) - tuiCtx, tuiCancel := context.WithCancel(app.globalCtx) + app.consumerWg.Add(1) + consumerCtx, consumerCancel := context.WithCancel(app.globalCtx) app.cleanupFuncs = append(app.cleanupFuncs, func() error { - slog.Debug("Cancelling TUI message handler") - tuiCancel() - app.tuiWG.Wait() + slog.Debug("Cancelling Consumer message handler") + consumerCancel() + app.consumerWg.Wait() return nil }) - defer app.tuiWG.Done() + defer app.consumerWg.Done() for { select { - case <-tuiCtx.Done(): - slog.Debug("TUI message handler shutting down") + case <-consumerCtx.Done(): + slog.Debug("Consumer message handler shutting down") return case msg, ok := <-app.events: if !ok { - slog.Debug("TUI message channel closed") + slog.Debug("Consumer message channel closed") return } - program.Send(msg) + target.Send(msg.(M)) } } } diff --git a/internal/cmd/acp.go b/internal/cmd/acp.go new file mode 100644 index 000000000..9899d5a3d --- /dev/null +++ b/internal/cmd/acp.go @@ -0,0 +1,39 @@ +package cmd + +import ( + "github.com/charmbracelet/crush/internal/acp" + "github.com/charmbracelet/crush/internal/event" + "github.com/spf13/cobra" +) + +var acpCmd = &cobra.Command{ + Use: "acp", + Short: "Start the crush in ACP mode", + Long: `Allows crush to be connected with ACP compliant clients.`, + + RunE: func(cmd *cobra.Command, args []string) error { + debug, _ := cmd.Flags().GetBool("debug") + yolo, _ := cmd.Flags().GetBool("yolo") + dataDir, _ := cmd.Flags().GetString("data-dir") + + acpServer, err := acp.NewServer(cmd.Context(), debug, yolo, dataDir) + if err != nil { + return err + } + defer acpServer.Shutdown() + + if shouldEnableMetrics(false) { + event.Init() + } + + event.AppInitialized() + defer event.AppExited() + + if err = acpServer.Run(); err != nil { + event.Error(err) + return err + } + + return nil + }, +} diff --git a/internal/cmd/logs.go b/internal/cmd/logs.go index 437208318..317f857e6 100644 --- a/internal/cmd/logs.go +++ b/internal/cmd/logs.go @@ -11,6 +11,7 @@ import ( "time" "github.com/charmbracelet/colorprofile" + "github.com/charmbracelet/crush/internal/acp" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/log/v2" "github.com/charmbracelet/x/term" @@ -45,17 +46,29 @@ var logsCmd = &cobra.Command{ return fmt.Errorf("failed to get tail flag: %v", err) } + acpLogs, err := cmd.Flags().GetBool("acp") + if err != nil { + return fmt.Errorf("failed to get acp flag: %v", err) + } + log.SetLevel(log.DebugLevel) log.SetOutput(os.Stdout) if !term.IsTerminal(os.Stdout.Fd()) { log.SetColorProfile(colorprofile.NoTTY) } - cfg, err := config.Load(cwd, dataDir, false) - if err != nil { - return fmt.Errorf("failed to load configuration: %v", err) + var logsFile string + if acpLogs { + logsFile = acp.LogsDir() + } else { + cfg, err := config.Load(cwd, dataDir, false) + if err != nil { + return fmt.Errorf("failed to load configuration: %v", err) + } + logsFile = cfg.Options.DataDirectory } - logsFile := filepath.Join(cfg.Options.DataDirectory, "logs", "crush.log") + + logsFile = filepath.Join(logsFile, "logs", fmt.Sprintf("%s.log", config.AppName)) _, err = os.Stat(logsFile) if os.IsNotExist(err) { log.Warn("Looks like you are not in a crush project. No logs found.") @@ -73,6 +86,7 @@ var logsCmd = &cobra.Command{ func init() { logsCmd.Flags().BoolP("follow", "f", false, "Follow log output") logsCmd.Flags().IntP("tail", "t", defaultTailLines, "Show only the last N lines default: 1000 for performance") + logsCmd.Flags().BoolP("acp", "", false, "Show logs for ACP server") } func followLogs(ctx context.Context, logsFile string, tailLines int) error { diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 005f2e86f..2eda09923 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -4,17 +4,16 @@ import ( "bytes" "context" "errors" - "fmt" "io" "log/slog" "os" - "path/filepath" "strconv" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/colorprofile" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/cwd" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" "github.com/charmbracelet/crush/internal/tui" @@ -40,6 +39,7 @@ func init() { updateProvidersCmd, logsCmd, schemaCmd, + acpCmd, ) } @@ -72,21 +72,21 @@ crush run "Explain the use of context in Go" crush -y `, RunE: func(cmd *cobra.Command, args []string) error { - app, err := setupApp(cmd) + appInstance, err := setupApp(cmd) if err != nil { return err } - defer app.Shutdown() + defer appInstance.Shutdown() event.AppInitialized() // Set up the TUI. program := tea.NewProgram( - tui.New(app), + tui.New(appInstance), tea.WithContext(cmd.Context()), tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state - go app.Subscribe(program) + go app.Subscribe[tea.Msg](appInstance, program) if _, err := program.Run(); err != nil { event.Error(err) @@ -151,12 +151,13 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { dataDir, _ := cmd.Flags().GetString("data-dir") ctx := cmd.Context() - cwd, err := ResolveCwd(cmd) + cwDir, _ := cmd.Flags().GetString("cwd") + cwDir, err := cwd.Resolve(cwDir) if err != nil { return nil, err } - cfg, err := config.Init(cwd, dataDir, debug) + cfg, err := config.Init(cwDir, dataDir, debug) if err != nil { return nil, err } @@ -166,7 +167,7 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { } cfg.Permissions.SkipRequests = yolo - if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil { + if err := cwd.CreateDotCrushDir(cfg.Options.DataDirectory); err != nil { return nil, err } @@ -182,21 +183,21 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return nil, err } - if shouldEnableMetrics() { + if shouldEnableMetrics(true) { event.Init() } return appInstance, nil } -func shouldEnableMetrics() bool { +func shouldEnableMetrics(useConfig bool) bool { if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v { return false } if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v { return false } - if config.Get().Options.DisableMetrics { + if useConfig && config.Get().Options.DisableMetrics { return false } return true @@ -219,34 +220,3 @@ func MaybePrependStdin(prompt string) (string, error) { } return string(bts) + "\n\n" + prompt, nil } - -func ResolveCwd(cmd *cobra.Command) (string, error) { - cwd, _ := cmd.Flags().GetString("cwd") - if cwd != "" { - err := os.Chdir(cwd) - if err != nil { - return "", fmt.Errorf("failed to change directory: %v", err) - } - return cwd, nil - } - cwd, err := os.Getwd() - if err != nil { - return "", fmt.Errorf("failed to get current working directory: %v", err) - } - return cwd, nil -} - -func createDotCrushDir(dir string) error { - if err := os.MkdirAll(dir, 0o700); err != nil { - return fmt.Errorf("failed to create data directory: %q %w", dir, err) - } - - gitIgnorePath := filepath.Join(dir, ".gitignore") - if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) { - if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil { - return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err) - } - } - - return nil -} diff --git a/internal/config/config.go b/internal/config/config.go index ff948b874..811458ac9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,7 +18,7 @@ import ( ) const ( - appName = "crush" + AppName = "crush" defaultDataDirectory = ".crush" ) diff --git a/internal/config/load.go b/internal/config/load.go index a219b7d1c..c7b209dff 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -60,7 +60,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { // Setup logs log.Setup( - filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)), + filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", AppName)), cfg.Options.Debug, ) @@ -536,7 +536,7 @@ func lookupConfigs(cwd string) []string { GlobalConfigData(), } - configNames := []string{appName + ".json", "." + appName + ".json"} + configNames := []string{AppName + ".json", "." + AppName + ".json"} foundConfigs, err := fsext.Lookup(cwd, configNames...) if err != nil { @@ -621,7 +621,7 @@ func hasAWSCredentials(env env.Env) bool { func GlobalConfig() string { xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") if xdgConfigHome != "" { - return filepath.Join(xdgConfigHome, appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(xdgConfigHome, AppName, fmt.Sprintf("%s.json", AppName)) } // return the path to the main config directory @@ -632,10 +632,10 @@ func GlobalConfig() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(localAppData, AppName, fmt.Sprintf("%s.json", AppName)) } - return filepath.Join(home.Dir(), ".config", appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(home.Dir(), ".config", AppName, fmt.Sprintf("%s.json", AppName)) } // GlobalConfigData returns the path to the main data directory for the application. @@ -643,7 +643,7 @@ func GlobalConfig() string { func GlobalConfigData() string { xdgDataHome := os.Getenv("XDG_DATA_HOME") if xdgDataHome != "" { - return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(xdgDataHome, AppName, fmt.Sprintf("%s.json", AppName)) } // return the path to the main data directory @@ -654,10 +654,10 @@ func GlobalConfigData() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(localAppData, AppName, fmt.Sprintf("%s.json", AppName)) } - return filepath.Join(home.Dir(), ".local", "share", appName, fmt.Sprintf("%s.json", appName)) + return filepath.Join(home.Dir(), ".local", "share", AppName, fmt.Sprintf("%s.json", AppName)) } func assignIfNil[T any](ptr **T, val T) { diff --git a/internal/config/provider.go b/internal/config/provider.go index 108d6a667..80d71d853 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -31,7 +31,7 @@ var ( func providerCacheFileData() string { xdgDataHome := os.Getenv("XDG_DATA_HOME") if xdgDataHome != "" { - return filepath.Join(xdgDataHome, appName, "providers.json") + return filepath.Join(xdgDataHome, AppName, "providers.json") } // return the path to the main data directory @@ -42,10 +42,10 @@ func providerCacheFileData() string { if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } - return filepath.Join(localAppData, appName, "providers.json") + return filepath.Join(localAppData, AppName, "providers.json") } - return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json") + return filepath.Join(home.Dir(), ".local", "share", AppName, "providers.json") } func saveProvidersInCache(path string, providers []catwalk.Provider) error { diff --git a/internal/cwd/cwd.go b/internal/cwd/cwd.go new file mode 100644 index 000000000..82b6f7c9e --- /dev/null +++ b/internal/cwd/cwd.go @@ -0,0 +1,37 @@ +package cwd + +import ( + "fmt" + "os" + "path/filepath" +) + +func Resolve(cwd string) (string, error) { + if cwd != "" { + err := os.Chdir(cwd) + if err != nil { + return "", fmt.Errorf("failed to change directory: %v", err) + } + return cwd, nil + } + cwd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("failed to get current working directory: %v", err) + } + return cwd, nil +} + +func CreateDotCrushDir(dir string) error { + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("failed to create data directory: %q %w", dir, err) + } + + gitIgnorePath := filepath.Join(dir, ".gitignore") + if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) { + if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil { + return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err) + } + } + + return nil +}