diff --git a/.commandcode/taste/taste.md b/.commandcode/taste/taste.md new file mode 100644 index 0000000..a6b2605 --- /dev/null +++ b/.commandcode/taste/taste.md @@ -0,0 +1,17 @@ +# Taste (Continuously Learned by [CommandCode][cmd]) + +[cmd]: https://commandcode.ai/ + +# git +- Use short conventional commit messages (type(scope): description) — describe only what changed, no explanations or AI slop. Confidence: 0.70 +- Do not add Co-authored-by or other attribution lines to commit messages. Confidence: 0.75 + +# audit +- Always run the /slop audit on changed code before committing, checking for AI-generated slop patterns (obvious comments, TODO placeholders, identity functions, robotic naming). Confidence: 0.85 + +# code-style +- Do not write AI-generated slop code — avoid obvious comments that restate the code, redundant doc strings, and architectural trivia in comments. Confidence: 0.85 +- Always follow Google Go style conventions. Confidence: 0.85 + +# workflow +- Always run markdownlint-cli2 after modifying markdown files. Confidence: 0.90 diff --git a/cmd/completion.go b/cmd/completion.go index f17071e..37b629f 100644 --- a/cmd/completion.go +++ b/cmd/completion.go @@ -185,7 +185,9 @@ PowerShell: if closeOut { if f, ok := out.(*os.File); ok { - f.Close() + if err := f.Close(); err != nil { + cmd.Printf("Error closing completion file: %v\n", err) + } } } }, diff --git a/cmd/context.go b/cmd/context.go index e18b806..789146b 100644 --- a/cmd/context.go +++ b/cmd/context.go @@ -6,21 +6,26 @@ import ( "github.com/dkmnx/kairo/internal/config" "github.com/dkmnx/kairo/internal/constants" + "github.com/dkmnx/kairo/internal/crypto" "github.com/spf13/cobra" ) type cliContextKey struct{} +// ConfigDirResolver resolves the default configuration directory. +type ConfigDirResolver func() (string, error) + // CLIContext holds shared CLI state: config directory, verbosity, config cache, // root context, and external dependencies. It is safe for concurrent use. type CLIContext struct { - configDir string - configDirMu sync.RWMutex - verbose bool - verboseMu sync.RWMutex - configCache *config.ConfigCache - rootCtx context.Context - deps *Deps + configDir string + configDirMu sync.RWMutex + configDirResolver ConfigDirResolver + verbose bool + verboseMu sync.RWMutex + configCache *config.ConfigCache + rootCtx context.Context + deps *Deps defaultProviderExplicit bool } @@ -28,9 +33,10 @@ type CLIContext struct { // NewCLIContext creates a CLIContext with default settings. func NewCLIContext() *CLIContext { return &CLIContext{ - configCache: config.NewConfigCache(constants.ConfigCacheTTL), - rootCtx: context.Background(), - deps: NewDeps(), + configDirResolver: config.DefaultConfigDir, + configCache: config.NewConfigCache(constants.ConfigCacheTTL), + rootCtx: context.Background(), + deps: NewDeps(), } } @@ -43,7 +49,7 @@ func (c *CLIContext) ConfigDir() string { return c.configDir } - dir, err := config.ConfigDir() + dir, err := c.configDirResolver() if err != nil { return "" } @@ -51,6 +57,14 @@ func (c *CLIContext) ConfigDir() string { return dir } +// SetConfigDirResolver sets the function used to locate the config directory. +func (c *CLIContext) SetConfigDirResolver(r ConfigDirResolver) { + c.configDirMu.Lock() + defer c.configDirMu.Unlock() + + c.configDirResolver = r +} + // SetConfigDir overrides the configuration directory. func (c *CLIContext) SetConfigDir(dir string) { c.configDirMu.Lock() @@ -90,6 +104,11 @@ func (c *CLIContext) Deps() *Deps { return c.deps } +// Crypto returns the crypto service for this CLI session. +func (c *CLIContext) Crypto() crypto.Service { + return c.deps.Crypto +} + // SetDeps replaces the external dependencies. For use in tests. func (c *CLIContext) SetDeps(d *Deps) { c.deps = d diff --git a/cmd/coverage_test.go b/cmd/coverage_test.go index 37f20d9..a9fe424 100644 --- a/cmd/coverage_test.go +++ b/cmd/coverage_test.go @@ -1,12 +1,15 @@ package cmd import ( + "bytes" + stderrors "errors" "os" "strings" "testing" "github.com/dkmnx/kairo/internal/config" "github.com/dkmnx/kairo/internal/providers" + "github.com/spf13/cobra" ) func TestHandleSecretsError(t *testing.T) { @@ -235,3 +238,17 @@ func TestProviderDefinition(t *testing.T) { t.Errorf("expected 'custom-provider', got %q", def.Name) } } + +func TestHandleConfigErrorNonBinary(t *testing.T) { + cmd := &cobra.Command{} + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetErr(buf) + + handleConfigError(cmd, stderrors.New("test error")) + + out := buf.String() + if !strings.Contains(out, "Error loading config") { + t.Errorf("expected error message, got: %s", out) + } +} diff --git a/cmd/delete.go b/cmd/delete.go index 8311f2a..7726745 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -2,9 +2,7 @@ package cmd import ( "context" - stderrors "errors" "fmt" - "io/fs" "os" "path/filepath" "strings" @@ -25,22 +23,12 @@ var deleteCmd = &cobra.Command{ Long: "Remove a provider from Kairo. If no provider is specified, shows an interactive list of configured providers.", Run: func(cmd *cobra.Command, args []string) { cliCtx := CLIContextFromCmd(cmd) - dir := requireConfigDir(cmd) - if dir == "" { - return - } - - cfg, err := config.LoadConfig(cliCtx.RootCtx(), dir) - if err != nil { - if stderrors.Is(err, fs.ErrNotExist) { - printNoProvidersMessage() - - return - } - handleConfigError(cmd, err) + cfg, err := loadConfigOrExit(cmd) + if err != nil || cfg == nil { return } + dir := cliCtx.ConfigDir() var target string if len(args) == 0 { @@ -113,7 +101,7 @@ var deleteCmd = &cobra.Command{ secretsPath := filepath.Join(dir, constants.SecretsFileName) keyPath := filepath.Join(dir, constants.KeyFileName) - if err := deleteProviderSecrets(cliCtx.RootCtx(), secretsPath, keyPath, target); err != nil { + if err := deleteProviderSecrets(cliCtx.RootCtx(), cliCtx.Crypto(), secretsPath, keyPath, target); err != nil { tap.Cancel(fmt.Sprintf("Failed to clean up secrets for '%s': %v", target, err)) return @@ -123,8 +111,8 @@ var deleteCmd = &cobra.Command{ }, } -func deleteProviderSecrets(ctx context.Context, secretsPath, keyPath, providerName string) error { - existingSecrets, err := crypto.DecryptSecretsBytes(ctx, secretsPath, keyPath) +func deleteProviderSecrets(ctx context.Context, svc crypto.Service, secretsPath, keyPath, providerName string) error { + existingSecrets, err := svc.DecryptSecretsBytes(ctx, secretsPath, keyPath) if err != nil { return errors.WrapError(errors.CryptoError, "failed to decrypt secrets for cleanup", err). @@ -157,7 +145,7 @@ func deleteProviderSecrets(ctx context.Context, secretsPath, keyPath, providerNa return nil } - if err := crypto.EncryptSecrets(ctx, secretsPath, keyPath, secretsContent); err != nil { + if err := svc.EncryptSecrets(ctx, secretsPath, keyPath, secretsContent); err != nil { return errors.WrapError(errors.CryptoError, "could not update secrets", err). WithContext("path", secretsPath) diff --git a/cmd/delete_test.go b/cmd/delete_test.go index a1785c5..7054f34 100644 --- a/cmd/delete_test.go +++ b/cmd/delete_test.go @@ -112,7 +112,7 @@ func TestDeleteCmdDeletesProviderSecrets(t *testing.T) { t.Fatalf("EncryptSecrets() error = %v", err) } - result, err := LoadSecrets(context.Background(), tmpDir) + result, err := LoadSecrets(NewCLIContext(), tmpDir) if err != nil { t.Fatalf("LoadSecrets() error = %v", err) } @@ -134,7 +134,7 @@ func TestDeleteCmdDeletesProviderSecrets(t *testing.T) { t.Fatalf("EncryptSecrets() error = %v", err) } - result, err = LoadSecrets(context.Background(), tmpDir) + result, err = LoadSecrets(NewCLIContext(), tmpDir) if err != nil { t.Fatalf("LoadSecrets() error = %v", err) } @@ -189,7 +189,7 @@ func TestDeleteProviderSecretsReturnsErrorOnBadKey(t *testing.T) { secretsPath := filepath.Join(tmpDir, constants.SecretsFileName) keyPath := filepath.Join(tmpDir, "nonexistent.key") - err := deleteProviderSecrets(context.Background(), secretsPath, keyPath, "testprovider") + err := deleteProviderSecrets(context.Background(), NewCLIContext().Crypto(), secretsPath, keyPath, "testprovider") if err == nil { t.Fatal("deleteProviderSecrets should return error when decryption fails") } @@ -217,7 +217,7 @@ func TestDeleteProviderSecretsPreservesMalformedLines(t *testing.T) { t.Fatalf("EncryptSecrets() error = %v", err) } - if err := deleteProviderSecrets(context.Background(), secretsPath, keyPath, "PROVIDER_TO_DELETE"); err != nil { + if err := deleteProviderSecrets(context.Background(), NewCLIContext().Crypto(), secretsPath, keyPath, "PROVIDER_TO_DELETE"); err != nil { t.Fatalf("deleteProviderSecrets() error = %v", err) } diff --git a/cmd/deps.go b/cmd/deps.go index 76e5ad3..006bc69 100644 --- a/cmd/deps.go +++ b/cmd/deps.go @@ -5,6 +5,7 @@ import ( "os" "os/exec" + "github.com/dkmnx/kairo/internal/crypto" "github.com/dkmnx/kairo/internal/ui" "github.com/dkmnx/kairo/internal/update" "github.com/dkmnx/kairo/internal/wrapper" @@ -65,5 +66,6 @@ func NewDeps() *Deps { Process: osProcessRunner{}, Wrapper: prodWrapperService{}, Update: &prodUpdateService{client: update.NewClient()}, + Crypto: crypto.DefaultService{}, } } diff --git a/cmd/execution_env.go b/cmd/execution_env.go index cd64649..c777e45 100644 --- a/cmd/execution_env.go +++ b/cmd/execution_env.go @@ -84,7 +84,7 @@ func BuildProviderEnv( ) (EnvBuildResult, error) { builtIn := BuildBuiltInEnvVars(provider) - secretsResult, err := LoadSecrets(cliCtx.RootCtx(), configDir) + secretsResult, err := LoadSecrets(cliCtx, configDir) if err != nil { if RequiresAPIKey(providerName) { return EnvBuildResult{}, err diff --git a/cmd/execution_env_test.go b/cmd/execution_env_test.go index bdacb17..d6e986c 100644 --- a/cmd/execution_env_test.go +++ b/cmd/execution_env_test.go @@ -223,3 +223,45 @@ func TestPiAPIKeyEnvVarMapping(t *testing.T) { }) } } + +func TestHarnessAPIKeyEnvVar(t *testing.T) { + tests := []struct { + provider string + want string + }{ + {"zai", "ZAI_API_KEY"}, + {"minimax", "MINIMAX_API_KEY"}, + {"anthropic", "ANTHROPIC_API_KEY"}, + } + + for _, tt := range tests { + t.Run(tt.provider, func(t *testing.T) { + got := HarnessAPIKeyEnvVar(tt.provider) + if got != tt.want { + t.Errorf("HarnessAPIKeyEnvVar(%q) = %q, want %q", tt.provider, got, tt.want) + } + }) + } +} + +func TestYoloModeFlag(t *testing.T) { + tests := []struct { + name string + harness string + want string + }{ + {"claude", harnessClaude, "--dangerously-skip-permissions"}, + {"qwen", harnessQwen, "--yolo"}, + {"pi", harnessPi, ""}, + {"crush", harnessCrush, "--yolo"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := yoloModeFlag(tt.harness) + if got != tt.want { + t.Errorf("yoloModeFlag(%q) = %q, want %q", tt.harness, got, tt.want) + } + }) + } +} diff --git a/cmd/execution_error.go b/cmd/execution_error.go index b9bea3d..02af3e0 100644 --- a/cmd/execution_error.go +++ b/cmd/execution_error.go @@ -7,25 +7,15 @@ import ( "github.com/dkmnx/kairo/internal/constants" kairoerrors "github.com/dkmnx/kairo/internal/errors" + "github.com/dkmnx/kairo/internal/harness" "github.com/dkmnx/kairo/internal/ui" "github.com/spf13/cobra" "gopkg.in/yaml.v3" ) -const claudeYoloFlag = "--dangerously-skip-permissions" -const qwenYoloFlag = "--yolo" - -func yoloModeFlag(harness string) string { - if harness == harnessQwen || harness == harnessCrush { - return qwenYoloFlag - } - if harness == harnessPi { - return "" - } - - return claudeYoloFlag +func yoloModeFlag(h string) string { + return harness.YoloFlag(h) } - func handleConfigError(cmd *cobra.Command, err error) { if isBinaryOutdatedError(err) { promptUpgrade(cmd, err) diff --git a/cmd/execution_harness.go b/cmd/execution_harness.go index 697dc3f..5f5021a 100644 --- a/cmd/execution_harness.go +++ b/cmd/execution_harness.go @@ -8,6 +8,8 @@ import ( "github.com/dkmnx/kairo/internal/config" kairoerrors "github.com/dkmnx/kairo/internal/errors" + "github.com/dkmnx/kairo/internal/execution" + "github.com/dkmnx/kairo/internal/harness" "github.com/dkmnx/kairo/internal/ui" "github.com/dkmnx/kairo/internal/version" "github.com/dkmnx/kairo/internal/wrapper" @@ -25,15 +27,11 @@ type HarnessRun struct { Harness string } -func qwenAuthArgs(model string) []string { - return []string{"--auth-type", "anthropic", "--model", model} -} - func executePi(cfg ExecutionConfig) error { cliArgs := cfg.HarnessArgs if cfg.Yolo { - flag := yoloModeFlag(cfg.HarnessToUse) + flag := harness.YoloFlag(cfg.HarnessToUse) if flag != "" { cliArgs = append([]string{flag}, cliArgs...) } @@ -59,10 +57,9 @@ func executePi(cfg ExecutionConfig) error { Harness: cfg.HarnessToUse, }) - ctx, cancel := context.WithCancel(CLIContextFromCmd(cfg.Cmd).RootCtx()) + ctx, cancel, stopSig := execution.StartSession(CLIContextFromCmd(cfg.Cmd).RootCtx()) defer cancel() - stopSignalHandler := setupSignalHandler(cancel) - defer stopSignalHandler() + defer stopSig() execCmd := cfg.Deps.Process.ExecCommandContext(ctx, piPath, cliArgs...) execCmd.Env = cfg.ProviderEnv @@ -131,10 +128,9 @@ func executeWithAuth(cfg ExecutionConfig) { } func executeWrapperWithAuth(cfg ExecutionConfig) { - ctx, cancel := context.WithCancel(CLIContextFromCmd(cfg.Cmd).RootCtx()) + ctx, cancel, stopSig := execution.StartSession(CLIContextFromCmd(cfg.Cmd).RootCtx()) defer cancel() - stopSignalHandler := setupSignalHandler(cancel) - defer stopSignalHandler() + defer stopSig() authDir, err := cfg.Deps.Wrapper.CreateTempAuthDir() if err != nil { @@ -146,7 +142,9 @@ func executeWrapperWithAuth(cfg ExecutionConfig) { var cleanupOnce sync.Once cleanup := func() { cleanupOnce.Do(func() { - _ = os.RemoveAll(authDir) + if err := os.RemoveAll(authDir); err != nil { + cfg.Cmd.Printf("Error cleaning up auth directory: %v\n", err) + } }) } defer cleanup() @@ -160,10 +158,10 @@ func executeWrapperWithAuth(cfg ExecutionConfig) { cliArgs := cfg.HarnessArgs if cfg.Yolo { - cliArgs = append([]string{yoloModeFlag(cfg.HarnessToUse)}, cliArgs...) + cliArgs = append([]string{harness.YoloFlag(cfg.HarnessToUse)}, cliArgs...) } - displayName, envVarName, extraArgs := harnessDispatch(cfg.HarnessToUse, cfg.ProviderName, cfg.Provider.Model) + displayName, envVarName, extraArgs := harness.Dispatch(cfg.HarnessToUse, cfg.ProviderName, cfg.Provider.Model) cliArgs = append(extraArgs, cliArgs...) run := HarnessRun{ @@ -183,21 +181,6 @@ func executeWrapperWithAuth(cfg ExecutionConfig) { } } -// harnessDispatch returns the display name, environment variable name, and any -// extra CLI arguments for the given harness. -func harnessDispatch(harness, providerName, model string) (displayName, envVarName string, extraArgs []string) { - switch harness { - case harnessQwen: - return "Qwen", "ANTHROPIC_API_KEY", qwenAuthArgs(model) - case harnessCrush: - return "Crush", HarnessAPIKeyEnvVar(providerName), nil - case harnessPi: - return "Pi", "", nil - default: - return "Claude", "", nil - } -} - func executeWithoutAuth(cfg ExecutionConfig) { if cfg.HarnessToUse == harnessPi { if err := executePi(cfg); err != nil { @@ -211,7 +194,7 @@ func executeWithoutAuth(cfg ExecutionConfig) { cliArgs := cfg.HarnessArgs if cfg.Yolo { - cliArgs = append([]string{yoloModeFlag(cfg.HarnessToUse)}, cliArgs...) + cliArgs = append([]string{harness.YoloFlag(cfg.HarnessToUse)}, cliArgs...) } if cfg.HarnessToUse == harnessQwen { @@ -241,10 +224,9 @@ func executeWithoutAuth(cfg ExecutionConfig) { }) } - ctx, cancel := context.WithCancel(CLIContextFromCmd(cfg.Cmd).RootCtx()) + ctx, cancel, stopSig := execution.StartSession(CLIContextFromCmd(cfg.Cmd).RootCtx()) defer cancel() - stopSignalHandler := setupSignalHandler(cancel) - defer stopSignalHandler() + defer stopSig() execCmd := cfg.Deps.Process.ExecCommandContext(ctx, harnessPath, cliArgs...) execCmd.Env = cfg.ProviderEnv @@ -252,7 +234,7 @@ func executeWithoutAuth(cfg ExecutionConfig) { execCmd.Stdout = os.Stdout execCmd.Stderr = os.Stderr - displayName, _, _ := harnessDispatch(cfg.HarnessToUse, cfg.ProviderName, cfg.Provider.Model) + displayName, _, _ := harness.Dispatch(cfg.HarnessToUse, cfg.ProviderName, cfg.Provider.Model) if err := execCmd.Run(); err != nil { cfg.Cmd.Printf("Error running %s: %v\n", displayName, err) diff --git a/cmd/execution_harness_test.go b/cmd/execution_harness_test.go index b36bcba..a2abe22 100644 --- a/cmd/execution_harness_test.go +++ b/cmd/execution_harness_test.go @@ -10,19 +10,20 @@ import ( "testing" "github.com/dkmnx/kairo/internal/config" + "github.com/dkmnx/kairo/internal/harness" "github.com/dkmnx/kairo/internal/wrapper" ) func TestQwenAuthArgs(t *testing.T) { - args := qwenAuthArgs("qwen-plus") - if len(args) != 4 { - t.Fatalf("qwenAuthArgs should return 4 elements, got %d", len(args)) + _, _, extraArgs := harness.Dispatch(harness.Qwen, "test", "qwen-plus") + if len(extraArgs) != 4 { + t.Fatalf("Dispatch should return 4 elements, got %d", len(extraArgs)) } - if args[0] != "--auth-type" || args[1] != "anthropic" { - t.Errorf("first two args should be --auth-type anthropic, got %v", args[:2]) + if extraArgs[0] != "--auth-type" || extraArgs[1] != "anthropic" { + t.Errorf("first two args should be --auth-type anthropic, got %v", extraArgs[:2]) } - if args[2] != "--model" || args[3] != "qwen-plus" { - t.Errorf("last two args should be --model qwen-plus, got %v", args[2:]) + if extraArgs[2] != "--model" || extraArgs[3] != "qwen-plus" { + t.Errorf("last two args should be --model qwen-plus, got %v", extraArgs[2:]) } } diff --git a/cmd/harness.go b/cmd/harness.go index a858fd9..f38798d 100644 --- a/cmd/harness.go +++ b/cmd/harness.go @@ -1,25 +1,24 @@ package cmd import ( - "errors" "fmt" "strings" "github.com/dkmnx/kairo/internal/config" - kairoerrors "github.com/dkmnx/kairo/internal/errors" + "github.com/dkmnx/kairo/internal/harness" "github.com/dkmnx/kairo/internal/ui" "github.com/spf13/cobra" ) const ( - harnessClaude = "claude" - harnessQwen = "qwen" - harnessPi = "pi" - harnessCrush = "crush" + harnessClaude = harness.Claude + harnessQwen = harness.Qwen + harnessPi = harness.Pi + harnessCrush = harness.Crush ) func isValidHarness(name string) bool { - return name == harnessClaude || name == harnessQwen || name == harnessPi || name == harnessCrush + return harness.IsValid(name) } var harnessGetCmd = &cobra.Command{ @@ -65,17 +64,14 @@ var harnessSetCmd = &cobra.Command{ cliCtx := CLIContextFromCmd(cmd) - cfg, err := cliCtx.ConfigCache().Get(cliCtx.RootCtx(), dir) - if err != nil && !errors.Is(err, kairoerrors.ErrConfigNotFound) { - handleConfigError(cmd, err) + cfg, err := loadConfigOrEmpty(cmd) + if err != nil { + ui.PrintError(fmt.Sprintf("Error loading config: %v", err)) return } - if err != nil { - cfg = &config.Config{ - Providers: make(map[string]config.Provider), - DefaultModels: make(map[string]string), - } + if cfg == nil { + return } cfg.DefaultHarness = harnessName @@ -104,22 +100,10 @@ func init() { } func resolveHarness(flagHarness, configHarness string) string { - harness := flagHarness - if harness == "" { - harness = configHarness - } - if harness == "" { - return harnessClaude + h := harness.Resolve(flagHarness, configHarness) + if h != flagHarness && h != configHarness && h == harnessClaude && (flagHarness != "" || configHarness != "") { + ui.PrintWarn(fmt.Sprintf("Unknown harness '%s', using 'claude'", flagHarness)) } - if !isValidHarness(harness) { - ui.PrintWarn(fmt.Sprintf("Unknown harness '%s', using 'claude'", harness)) - - return harnessClaude - } - - return harness -} -func harnessBinary(harness string) string { - return harness + return h } diff --git a/cmd/harness_test.go b/cmd/harness_test.go index 9e1bd80..44d0b63 100644 --- a/cmd/harness_test.go +++ b/cmd/harness_test.go @@ -170,13 +170,6 @@ func TestGetHarnessWithPi(t *testing.T) { } } -func TestGetHarnessBinaryPi(t *testing.T) { - result := harnessBinary("pi") - if result != "pi" { - t.Errorf("harnessBinary('pi') = %q, want %q", result, "pi") - } -} - func TestHarnessSetCaseInsensitive(t *testing.T) { tests := []struct { name string @@ -250,28 +243,6 @@ func TestGetHarness(t *testing.T) { } } -func TestGetHarnessBinary(t *testing.T) { - tests := []struct { - name string - harness string - expected string - }{ - {"claude returns claude", "claude", "claude"}, - {"qwen returns qwen", "qwen", "qwen"}, - {"pi returns pi", "pi", "pi"}, - {"crush returns crush", "crush", "crush"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := harnessBinary(tt.harness) - if result != tt.expected { - t.Errorf("harnessBinary() = %q, want %q", result, tt.expected) - } - }) - } -} - func TestGetHarnessWithExistingConfig(t *testing.T) { originalConfigDir := configDir() defer func() { setConfigDir(originalConfigDir) }() diff --git a/cmd/interfaces.go b/cmd/interfaces.go index ef14e41..0e69c87 100644 --- a/cmd/interfaces.go +++ b/cmd/interfaces.go @@ -4,6 +4,7 @@ import ( "context" "os/exec" + "github.com/dkmnx/kairo/internal/crypto" "github.com/dkmnx/kairo/internal/update" "github.com/dkmnx/kairo/internal/wrapper" ) @@ -39,4 +40,5 @@ type Deps struct { Process ProcessRunner Wrapper WrapperService Update UpdateService + Crypto crypto.Service } diff --git a/cmd/list.go b/cmd/list.go index dab0a72..fabe1ca 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -1,9 +1,7 @@ package cmd import ( - stderrors "errors" "fmt" - "io/fs" "sort" "github.com/dkmnx/kairo/internal/config" @@ -17,23 +15,8 @@ var listCmd = &cobra.Command{ Short: "List configured providers", Long: "Display all configured providers and their status", Run: func(cmd *cobra.Command, args []string) { - cliCtx := CLIContextFromCmd(cmd) - dir := requireConfigDir(cmd) - if dir == "" { - ui.PrintInfo("Run 'kairo setup' to configure providers") - - return - } - - cfg, err := config.LoadConfig(cliCtx.RootCtx(), dir) - if err != nil { - if stderrors.Is(err, fs.ErrNotExist) { - printNoProvidersMessage() - - return - } - handleConfigError(cmd, err) - + cfg, err := loadConfigOrExit(cmd) + if err != nil || cfg == nil { return } diff --git a/cmd/root.go b/cmd/root.go index 7dd151d..8e08eeb 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -14,9 +14,9 @@ import ( ) var ( - harnessFlag string - yoloFlag bool - verboseFlag bool + harnessFlag string + skipPermissionsFlag bool + verboseFlag bool ) func setConfigDir(dir string) { @@ -79,7 +79,7 @@ func Execute() error { rootCmd.SetArgs(nil) }() - defaultCLIContext.SetDefaultProviderExplicit(hasDoubleDash(args)) + defaultCLIContext.SetDefaultProviderExplicit(hasArgsSeparator(args)) rootCmd.SetArgs(args) @@ -90,7 +90,7 @@ func init() { rootCmd.PersistentFlags().String("config", "", "Config directory (default is platform-specific)") rootCmd.PersistentFlags().BoolVarP(&verboseFlag, "verbose", "v", false, "Verbose output") rootCmd.Flags().StringVar(&harnessFlag, "harness", "", "CLI harness to use (claude, qwen, pi, or crush)") - rootCmd.Flags().BoolVarP(&yoloFlag, "yolo", "y", false, + rootCmd.Flags().BoolVarP(&skipPermissionsFlag, "yolo", "y", false, "Skip permission prompts (--dangerously-skip-permissions for Claude, --yolo for Qwen)") rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { @@ -196,11 +196,11 @@ func runPiProvider( Cmd: cmd, ProviderEnv: providerEnv, HarnessToUse: harnessToUse, - HarnessBinary: harnessBinary(harnessToUse), + HarnessBinary: harnessToUse, Provider: provider, ProviderName: providerName, HarnessArgs: harnessArgs, - Yolo: yoloFlag, + Yolo: skipPermissionsFlag, Deps: cliCtx.Deps(), } @@ -233,12 +233,12 @@ func runStandardProvider( Cmd: cmd, ProviderEnv: envResult.ProviderEnv, HarnessToUse: harnessToUse, - HarnessBinary: harnessBinary(harnessToUse), + HarnessBinary: harnessToUse, Provider: provider, ProviderName: providerName, HarnessArgs: harnessArgs, APIKey: apiKey, - Yolo: yoloFlag, + Yolo: skipPermissionsFlag, Deps: cliCtx.Deps(), } @@ -273,7 +273,7 @@ func splitArgs(args []string) ([]string, []string) { return args, nil } -func hasDoubleDash(args []string) bool { +func hasArgsSeparator(args []string) bool { for i := 0; i < len(args); i++ { if args[i] == "--" { return true @@ -295,30 +295,20 @@ func hasDoubleDash(args []string) bool { func providerFromArgs(cmd *cobra.Command, cfg *config.Config, args []string) (string, []string) { kairoArgs, harnessArgs := splitArgs(args) - switch { - case len(args) > 0 && strings.HasPrefix(args[0], "-") && cfg.DefaultProvider != "": - args = []string{cfg.DefaultProvider} - harnessArgs = kairoArgs - case len(kairoArgs) > 0 && len(args) > 1 && kairoArgs[0] != args[0]: - args = append([]string{args[0]}, kairoArgs...) - case len(args) > 1: - harnessArgs = args[1:] - args = args[:1] - } - - providerName := args[0] + if len(kairoArgs) > 0 && !strings.HasPrefix(kairoArgs[0], "-") { + harnessArgs = append(kairoArgs[1:], harnessArgs...) - if strings.HasPrefix(providerName, "-") { - if cfg.DefaultProvider == "" { - cmd.Println("Error: No default provider set and first argument looks like a flag") - cmd.Println("Run 'kairo setup' to configure a provider") + return kairoArgs[0], harnessArgs + } - return "", nil - } - providerName = cfg.DefaultProvider + if cfg.DefaultProvider != "" { + return cfg.DefaultProvider, kairoArgs } - return providerName, harnessArgs + cmd.Println("Error: No default provider set and first argument looks like a flag") + cmd.Println("Run 'kairo setup' to configure a provider") + + return "", nil } func resolveProviderAndArgs(cmd *cobra.Command, cfg *config.Config, args []string) ([]string, []string, string) { diff --git a/cmd/root_args_test.go b/cmd/root_args_test.go index bd3d6af..eea902f 100644 --- a/cmd/root_args_test.go +++ b/cmd/root_args_test.go @@ -87,7 +87,7 @@ func TestGetProviderFromArgs(t *testing.T) { } } -func TestHasDoubleDash(t *testing.T) { +func TestHasArgsSeparator(t *testing.T) { tests := []struct { name string args []string @@ -106,9 +106,9 @@ func TestHasDoubleDash(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := hasDoubleDash(tt.args) + got := hasArgsSeparator(tt.args) if got != tt.want { - t.Errorf("hasDoubleDash(%v) = %v, want %v", tt.args, got, tt.want) + t.Errorf("hasArgsSeparator(%v) = %v, want %v", tt.args, got, tt.want) } }) } diff --git a/cmd/root_cmd_test.go b/cmd/root_cmd_test.go index 31af170..b642b5d 100644 --- a/cmd/root_cmd_test.go +++ b/cmd/root_cmd_test.go @@ -212,9 +212,9 @@ func TestRootCmd(t *testing.T) { func TestRootCmdGetConfigDir(t *testing.T) { t.Run("returns flag value when set", func(t *testing.T) { - originalConfigDir := configDir() + originalDir := configDir() setConfigDir("/custom/config/dir") - defer func() { setConfigDir(originalConfigDir) }() + defer func() { setConfigDir(originalDir) }() result := configDir() if result != "/custom/config/dir" { @@ -222,26 +222,44 @@ func TestRootCmdGetConfigDir(t *testing.T) { } }) - t.Run("returns env default when flag is empty", func(t *testing.T) { - originalConfigDir := configDir() + t.Run("returns resolver default when flag is empty", func(t *testing.T) { + cliCtx := NewCLIContext() + cliCtx.SetConfigDirResolver(func() (string, error) { + return "/from/resolver", nil + }) + + // Override the global context to use our injected resolver + prevCtx := defaultCLIContext + defaultCLIContext = cliCtx setConfigDir("") - defer func() { setConfigDir(originalConfigDir) }() + defer func() { + defaultCLIContext = prevCtx + setConfigDir("") + }() result := configDir() - // We can't easily test the exact value without mocking env package - if result == "" { - t.Skip("Cannot test env.GetConfigDir() without mocking") + if result != "/from/resolver" { + t.Errorf("configDir() = %q, want %q", result, "/from/resolver") } }) - t.Run("empty flag value uses default", func(t *testing.T) { - originalConfigDir := configDir() + t.Run("empty resolver returns empty string", func(t *testing.T) { + cliCtx := NewCLIContext() + cliCtx.SetConfigDirResolver(func() (string, error) { + return "", nil + }) + + prevCtx := defaultCLIContext + defaultCLIContext = cliCtx setConfigDir("") - defer func() { setConfigDir(originalConfigDir) }() + defer func() { + defaultCLIContext = prevCtx + setConfigDir("") + }() result := configDir() - if result == "" { - t.Skip("Cannot mock env.GetConfigDir() without dependency injection") + if result != "" { + t.Errorf("configDir() = %q, want empty with nil resolver", result) } }) } diff --git a/cmd/root_provider_test.go b/cmd/root_provider_test.go index 64e801b..5424f8c 100644 --- a/cmd/root_provider_test.go +++ b/cmd/root_provider_test.go @@ -113,7 +113,7 @@ func TestRunPiProviderWithAuth(t *testing.T) { } }) cliCtx.SetDeps(d) - yoloFlag = false + skipPermissionsFlag = false harnessFlag = "" runPiProvider(rootCmd, cliCtx, cfg, cfg.Providers["zai"], "zai", "pi", []string{"hello"}) diff --git a/cmd/setup.go b/cmd/setup.go index 7210e7d..54f0219 100644 --- a/cmd/setup.go +++ b/cmd/setup.go @@ -38,7 +38,7 @@ func configureProvider(params ProviderSetup) (string, error) { } apiKey := promptForAPIKey(promptCfg) - if err := validate.ValidateAPIKey(apiKey, definition.Name); err != nil { + if err := definition.ValidateAPIKey(apiKey); err != nil { return "", err } @@ -78,7 +78,7 @@ func configureProvider(params ProviderSetup) (string, error) { } params.Secrets[APIKeyEnvVarName(validatedName)] = apiKey - if err := SaveSecrets(params.CLIContext.RootCtx(), params.SecretsPath, params.KeyPath, params.Secrets); err != nil { + if err := SaveSecrets(params.CLIContext, params.SecretsPath, params.KeyPath, params.Secrets); err != nil { return "", err } @@ -100,7 +100,7 @@ func runResetSecrets(cliCtx *CLIContext, configDir string, secretsResult Secrets } if err := ResetSecretsFiles( - cliCtx.RootCtx(), configDir, secretsResult.SecretsPath, secretsResult.KeyPath, + cliCtx.RootCtx(), cliCtx, configDir, secretsResult.SecretsPath, secretsResult.KeyPath, ); err != nil { return err } @@ -137,7 +137,7 @@ var setupCmd = &cobra.Command{ return } - secretsResult, err := LoadSecrets(cliCtx.RootCtx(), configDir) + secretsResult, err := LoadSecrets(cliCtx, configDir) if err != nil { if setupResetSecrets { if err := runResetSecrets(cliCtx, configDir, secretsResult); err != nil { diff --git a/cmd/setup_config.go b/cmd/setup_config.go index c0f9f70..8452011 100644 --- a/cmd/setup_config.go +++ b/cmd/setup_config.go @@ -20,7 +20,7 @@ func EnsureConfigDir(cliCtx *CLIContext, configDir string) error { return kairoerrors.WrapError(kairoerrors.FileSystemError, "creating config directory", err) } - if err := crypto.EnsureKeyExists(cliCtx.RootCtx(), configDir); err != nil { + if err := cliCtx.Crypto().EnsureKeyExists(cliCtx.RootCtx(), configDir); err != nil { return kairoerrors.WrapError(kairoerrors.CryptoError, "creating encryption key", err) } @@ -79,7 +79,8 @@ type SecretsResult struct { } // LoadSecrets loads and decrypts secrets from the config directory. -func LoadSecrets(ctx context.Context, configDir string) (SecretsResult, error) { +func LoadSecrets(cliCtx *CLIContext, configDir string) (SecretsResult, error) { + ctx := cliCtx.RootCtx() result := SecretsResult{ Secrets: make(map[string]string), } @@ -91,7 +92,7 @@ func LoadSecrets(ctx context.Context, configDir string) (SecretsResult, error) { return result, nil } - existingSecrets, err := crypto.DecryptSecretsBytes(ctx, result.SecretsPath, result.KeyPath) + existingSecrets, err := cliCtx.Crypto().DecryptSecretsBytes(ctx, result.SecretsPath, result.KeyPath) if err != nil { return SecretsResult{}, err } @@ -106,7 +107,7 @@ func LoadSecrets(ctx context.Context, configDir string) (SecretsResult, error) { } // ResetSecretsFiles deletes and regenerates the encryption key and secrets files. -func ResetSecretsFiles(ctx context.Context, configDir, secretsPath, keyPath string) error { +func ResetSecretsFiles(ctx context.Context, cliCtx *CLIContext, configDir, secretsPath, keyPath string) error { if err := os.Remove(keyPath); err != nil && !errors.Is(err, fs.ErrNotExist) { return kairoerrors.WrapError(kairoerrors.FileSystemError, "failed to remove old key file", err) @@ -117,7 +118,7 @@ func ResetSecretsFiles(ctx context.Context, configDir, secretsPath, keyPath stri "failed to remove old secrets file", err) } - if err := crypto.EnsureKeyExists(ctx, configDir); err != nil { + if err := cliCtx.Crypto().EnsureKeyExists(ctx, configDir); err != nil { return kairoerrors.WrapError(kairoerrors.CryptoError, "failed to generate new encryption key", err) } @@ -126,9 +127,9 @@ func ResetSecretsFiles(ctx context.Context, configDir, secretsPath, keyPath stri } // SaveSecrets encrypts and writes the secrets map to the secrets file. -func SaveSecrets(ctx context.Context, secretsPath, keyPath string, secretsMap map[string]string) error { +func SaveSecrets(cliCtx *CLIContext, secretsPath, keyPath string, secretsMap map[string]string) error { secretsContent := secrets.Format(secretsMap) - if err := crypto.EncryptSecrets(ctx, secretsPath, keyPath, secretsContent); err != nil { + if err := cliCtx.Crypto().EncryptSecrets(cliCtx.RootCtx(), secretsPath, keyPath, secretsContent); err != nil { return kairoerrors.WrapError(kairoerrors.CryptoError, "saving secrets", err) } diff --git a/cmd/setup_config_test.go b/cmd/setup_config_test.go index bbb5a03..8ff9cce 100644 --- a/cmd/setup_config_test.go +++ b/cmd/setup_config_test.go @@ -7,21 +7,21 @@ import ( "testing" "github.com/dkmnx/kairo/internal/constants" - "github.com/dkmnx/kairo/internal/crypto" ) func TestResetSecretsFiles(t *testing.T) { t.Run("deletes old files and regenerates key", func(t *testing.T) { tmpDir := t.TempDir() + cliCtx := NewCLIContext() - if err := crypto.EnsureKeyExists(context.Background(), tmpDir); err != nil { + if err := cliCtx.Crypto().EnsureKeyExists(context.Background(), tmpDir); err != nil { t.Fatalf("EnsureKeyExists() error = %v", err) } keyPath := filepath.Join(tmpDir, constants.KeyFileName) secretsPath := filepath.Join(tmpDir, constants.SecretsFileName) - if err := crypto.EncryptSecrets(context.Background(), secretsPath, keyPath, "TEST_KEY=value\n"); err != nil { + if err := cliCtx.Crypto().EncryptSecrets(context.Background(), secretsPath, keyPath, "TEST_KEY=value\n"); err != nil { t.Fatalf("EncryptSecrets() error = %v", err) } @@ -34,7 +34,7 @@ func TestResetSecretsFiles(t *testing.T) { t.Fatalf("secrets file should exist before reset: %v", err) } - if err := ResetSecretsFiles(context.Background(), tmpDir, secretsPath, keyPath); err != nil { + if err := ResetSecretsFiles(context.Background(), cliCtx, tmpDir, secretsPath, keyPath); err != nil { t.Fatalf("ResetSecretsFiles() error = %v", err) } @@ -58,7 +58,8 @@ func TestResetSecretsFiles(t *testing.T) { keyPath := filepath.Join(tmpDir, constants.KeyFileName) secretsPath := filepath.Join(tmpDir, constants.SecretsFileName) - err := ResetSecretsFiles(context.Background(), tmpDir, secretsPath, keyPath) + cliCtx := NewCLIContext() + err := ResetSecretsFiles(context.Background(), cliCtx, tmpDir, secretsPath, keyPath) if err != nil { t.Fatalf("ResetSecretsFiles() should succeed when files don't exist, got: %v", err) } @@ -108,8 +109,9 @@ func TestEnsureConfigDir(t *testing.T) { func TestSaveSecrets(t *testing.T) { t.Run("encrypts and saves secrets", func(t *testing.T) { tmpDir := t.TempDir() + cliCtx := NewCLIContext() - if err := crypto.EnsureKeyExists(context.Background(), tmpDir); err != nil { + if err := cliCtx.Crypto().EnsureKeyExists(context.Background(), tmpDir); err != nil { t.Fatalf("EnsureKeyExists() error = %v", err) } @@ -120,7 +122,7 @@ func TestSaveSecrets(t *testing.T) { "ZAI_API_KEY": "sk-test-123", } - err := SaveSecrets(context.Background(), secretsPath, keyPath, secrets) + err := SaveSecrets(cliCtx, secretsPath, keyPath, secrets) if err != nil { t.Fatalf("SaveSecrets() error = %v", err) } @@ -129,7 +131,7 @@ func TestSaveSecrets(t *testing.T) { t.Errorf("secrets file should exist: %v", err) } - decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) + decrypted, err := cliCtx.Crypto().DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { t.Fatalf("DecryptSecrets() error = %v", err) } @@ -140,7 +142,8 @@ func TestSaveSecrets(t *testing.T) { }) t.Run("error with invalid key path", func(t *testing.T) { - err := SaveSecrets(context.Background(), "/nonexistent/secrets", "/nonexistent/key", map[string]string{"K": "V"}) + cliCtx := NewCLIContext() + err := SaveSecrets(cliCtx, "/nonexistent/secrets", "/nonexistent/key", map[string]string{"K": "V"}) if err == nil { t.Error("SaveSecrets() should fail with invalid key path") } diff --git a/cmd/setup_helpers_test.go b/cmd/setup_helpers_test.go index 2a871b5..e6b0ca8 100644 --- a/cmd/setup_helpers_test.go +++ b/cmd/setup_helpers_test.go @@ -8,7 +8,6 @@ import ( "github.com/dkmnx/kairo/internal/config" "github.com/dkmnx/kairo/internal/constants" - "github.com/dkmnx/kairo/internal/crypto" "github.com/dkmnx/kairo/internal/providers" secretspkg "github.com/dkmnx/kairo/internal/secrets" "github.com/dkmnx/kairo/internal/validate" @@ -290,7 +289,8 @@ func TestSaveProviderConfiguration(t *testing.T) { t.Run("saves new provider and becomes default", func(t *testing.T) { tmpDir := t.TempDir() - if err := crypto.EnsureKeyExists(context.Background(), tmpDir); err != nil { + cliCtx := NewCLIContext() + if err := cliCtx.Crypto().EnsureKeyExists(context.Background(), tmpDir); err != nil { t.Fatalf("EnsureKeyExists() error = %v", err) } @@ -320,7 +320,7 @@ func TestSaveProviderConfiguration(t *testing.T) { secrets := make(map[string]string) secrets["TESTPROVIDER_API_KEY"] = "test-api-key" - err = SaveSecrets(context.Background(), secretsPath, keyPath, secrets) + err = SaveSecrets(cliCtx, secretsPath, keyPath, secrets) if err != nil { t.Fatalf("SaveSecrets() error = %v", err) } @@ -329,7 +329,7 @@ func TestSaveProviderConfiguration(t *testing.T) { t.Errorf("DefaultProvider = %q, want %q", cfg.DefaultProvider, "testprovider") } - result, err := LoadSecrets(context.Background(), tmpDir) + result, err := LoadSecrets(cliCtx, tmpDir) if err != nil { t.Fatalf("LoadSecrets() error = %v", err) } @@ -342,7 +342,8 @@ func TestSaveProviderConfiguration(t *testing.T) { t.Run("saves provider without becoming default when default exists", func(t *testing.T) { tmpDir := t.TempDir() - if err := crypto.EnsureKeyExists(context.Background(), tmpDir); err != nil { + cliCtx := NewCLIContext() + if err := cliCtx.Crypto().EnsureKeyExists(context.Background(), tmpDir); err != nil { t.Fatalf("EnsureKeyExists() error = %v", err) } @@ -374,7 +375,7 @@ func TestSaveProviderConfiguration(t *testing.T) { secrets := make(map[string]string) secrets["NEWPROVIDER_API_KEY"] = "new-api-key" - err = SaveSecrets(context.Background(), secretsPath, keyPath, secrets) + err = SaveSecrets(cliCtx, secretsPath, keyPath, secrets) if err != nil { t.Fatalf("SaveSecrets() error = %v", err) } diff --git a/cmd/setup_secrets_test.go b/cmd/setup_secrets_test.go index 1d28e03..2c1d268 100644 --- a/cmd/setup_secrets_test.go +++ b/cmd/setup_secrets_test.go @@ -147,7 +147,7 @@ func TestLoadSecrets(t *testing.T) { t.Fatal(err) } - result, err := LoadSecrets(context.Background(), tmpDir) + result, err := LoadSecrets(NewCLIContext(), tmpDir) if err != nil { t.Fatalf("LoadSecrets() error = %v", err) } @@ -173,7 +173,7 @@ func TestLoadSecretsNoSecretsFile(t *testing.T) { t.Fatal(err) } - result, err := LoadSecrets(context.Background(), tmpDir) + result, err := LoadSecrets(NewCLIContext(), tmpDir) if err != nil { t.Fatalf("LoadSecrets() error = %v", err) } @@ -205,7 +205,7 @@ func TestLoadSecretsWithCorruptedFile(t *testing.T) { t.Fatal(err) } - _, err := LoadSecrets(context.Background(), tmpDir) + _, err := LoadSecrets(NewCLIContext(), tmpDir) if err == nil { t.Fatal("Expected error for corrupted secrets file, got nil") } @@ -229,7 +229,7 @@ func TestLoadSecretsWithCorruptedKey(t *testing.T) { t.Fatal(err) } - _, err := LoadSecrets(context.Background(), tmpDir) + _, err := LoadSecrets(NewCLIContext(), tmpDir) if err == nil { t.Fatal("Expected error for corrupted key file, got nil") } diff --git a/cmd/test_helpers.go b/cmd/test_helpers.go index e938556..05fea45 100644 --- a/cmd/test_helpers.go +++ b/cmd/test_helpers.go @@ -4,6 +4,7 @@ import ( "context" "os/exec" + "github.com/dkmnx/kairo/internal/crypto" "github.com/dkmnx/kairo/internal/update" "github.com/dkmnx/kairo/internal/wrapper" ) @@ -93,5 +94,5 @@ func testDeps(overrides ...func(mp *mockProcess, mw *mockWrapper, mu *mockUpdate fn(mp, mw, mu) } - return &Deps{Process: mp, Wrapper: mw, Update: mu} + return &Deps{Process: mp, Wrapper: mw, Update: mu, Crypto: crypto.DefaultService{}} } diff --git a/cmd/util.go b/cmd/util.go index ae951ea..35b7d0c 100644 --- a/cmd/util.go +++ b/cmd/util.go @@ -2,14 +2,12 @@ package cmd import ( stderrors "errors" - "io/fs" "os" - "os/signal" "strings" - "syscall" "github.com/dkmnx/kairo/internal/config" "github.com/dkmnx/kairo/internal/constants" + kairoerrors "github.com/dkmnx/kairo/internal/errors" "github.com/dkmnx/kairo/internal/ui" "github.com/spf13/cobra" ) @@ -46,7 +44,7 @@ func loadConfigOrExit(cmd *cobra.Command) (*config.Config, error) { cliCtx := CLIContextFromCmd(cmd) cfg, err := cliCtx.ConfigCache().Get(cliCtx.RootCtx(), dir) if err != nil { - if stderrors.Is(err, fs.ErrNotExist) { + if stderrors.Is(err, kairoerrors.ErrConfigNotFound) { printNoProvidersMessage() return nil, nil @@ -60,6 +58,21 @@ func loadConfigOrExit(cmd *cobra.Command) (*config.Config, error) { return cfg, nil } +func loadConfigOrEmpty(cmd *cobra.Command) (*config.Config, error) { + cfg, err := loadConfigOrExit(cmd) + if err != nil { + return nil, err + } + if cfg == nil { + return &config.Config{ + Providers: make(map[string]config.Provider), + DefaultModels: make(map[string]string), + }, nil + } + + return cfg, nil +} + // printNoProvidersMessage prints a standard message indicating no providers // are configured and directs the user to run setup. func printNoProvidersMessage() { @@ -120,26 +133,3 @@ func mergeEnvVars(envs ...[]string) []string { return deduped } - -// setupSignalHandler registers a goroutine that calls cancel on SIGINT or -// SIGTERM. It returns a stop function that should be called for cleanup when -// the command completes before any signal is received. -func setupSignalHandler(cancel func()) func() { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - done := make(chan struct{}) - go func() { - select { - case <-sigChan: - signal.Stop(sigChan) - if cancel != nil { - cancel() - } - case <-done: - signal.Stop(sigChan) - } - }() - - return func() { close(done) } -} diff --git a/cmd/util_test.go b/cmd/util_test.go index 3c443ad..0cad947 100644 --- a/cmd/util_test.go +++ b/cmd/util_test.go @@ -126,15 +126,6 @@ func TestRunningWithRaceDetector(t *testing.T) { } } -func TestSetupSignalHandler(t *testing.T) { - t.Run("signal handler sets up without panic", func(t *testing.T) { - stop := setupSignalHandler(func() { - // cancel callback - }) - stop() - }) -} - func TestPrintSecretsRecoveryHelp(t *testing.T) { printSecretsRecoveryHelp() } diff --git a/docs/contributing/README.md b/docs/contributing/README.md index d777c93..4f174d5 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -71,10 +71,14 @@ Closes #42 ## Code Style -- Follow [Effective Go](https://go.dev/doc/effective_go) -- Use `gofmt` -- Add godoc comments -- Return typed errors +- Follow [Google Go Style Guide](https://google.github.io/styleguide/go) +- Use `gofmt` for formatting +- Use `golangci-lint` (see `.golangci.yml`) +- MixedCaps naming, no `Get` prefix on getters +- Short receiver names (1-2 letters), consistent per type +- Doc comments on all top-level exported names +- Indent error flow, early returns +- Return typed errors from `internal/` packages ## Testing diff --git a/docs/guides/development-guide.md b/docs/guides/development-guide.md index 752dc1a..9971dec 100644 --- a/docs/guides/development-guide.md +++ b/docs/guides/development-guide.md @@ -66,6 +66,26 @@ kairo/ ## Adding a Provider +### Custom providers (config only) + +Define providers in `config.yaml` under `custom_providers` — no code changes needed: + +```yaml +custom_providers: + my-llm: + name: My LLM + base_url: https://api.example.com/anthropic + model: custom-model + requires_api_key: true + api_key_env_var: MY_LLM_API_KEY + min_key_length: 32 + key_prefix: sk- +``` + +Custom providers override built-in providers with the same key. + +### Built-in providers (code) + 1. Add the provider definition in `internal/providers/registry.go`: ```go @@ -76,22 +96,22 @@ var builtInProviders = map[string]ProviderDefinition{ BaseURL: "https://api.newprovider.com/anthropic", Model: "new-model", RequiresAPIKey: true, + APIKeyEnvVar: "NEWPROVIDER_API_KEY", + KeyFormat: KeyFormatMin32, }, } ``` 1. Add the provider key to `providerOrder` in the same file so it appears in setup menus. -2. If needed, add provider-specific API key validation in `internal/validate/api_key.go`. - -3. Update user and reference docs. - -4. Run targeted tests: +2. Run targeted tests: ```bash go test ./internal/providers/... ./internal/validate/... ``` +1. Update user and reference docs. + ## Testing ```bash diff --git a/docs/guides/user-guide.md b/docs/guides/user-guide.md index df0a4c9..7885293 100644 --- a/docs/guides/user-guide.md +++ b/docs/guides/user-guide.md @@ -81,11 +81,10 @@ kairo -- "Quick question" | `kairo [args]` | Execute with a specific provider | | `kairo -- [args]` | Execute with the default provider | | `kairo harness get` | Get current harness | -| `kairo harness set ` | Set default harness (`claude`, `qwen`, `pi`, or `crush`) | +| `kairo harness set ` | Set default harness (claude, qwen, pi, or crush) | | `kairo update` | Update to the latest version | | `kairo version` | Show version | | `kairo completion [shell]` | Generate shell completion script | - ## Supported Providers | Provider | API Key Env Var | API Key Required | @@ -109,9 +108,7 @@ kairo -- "Quick question" | `fireworks` | `FIREWORKS_API_KEY` | Yes | | `azure-openai-responses` | `AZURE_OPENAI_API_KEY` | Yes | | `minimax-cn` | `MINIMAX_CN_API_KEY` | Yes | - | `custom` | user-defined | Yes | - Details: [Provider Reference](../reference/providers.md) ## Configuration @@ -122,7 +119,6 @@ Details: [Provider Reference](../reference/providers.md) | ----------- | ---------------------------------------- | | Linux/macOS | `~/.config/kairo/` | | Windows | `%USERPROFILE%\AppData\Roaming\kairo\` | - ### Files | File | Purpose | @@ -130,7 +126,6 @@ Details: [Provider Reference](../reference/providers.md) | `config.yaml` | Provider and harness settings | | `secrets.age` | Encrypted API keys | | `age.key` | Encryption private key | - Details: [Configuration Reference](../reference/configuration.md) ## Security @@ -168,7 +163,6 @@ Common issues: | `provider not found` | Run `kairo setup` | | `invalid API key` | Reconfigure with `kairo setup` | | `failed to decrypt` | Restore backup or run `kairo setup --reset-secrets` | - Full guide: [Troubleshooting](../troubleshooting/README.md) ## Next Steps diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 81736fc..0e995c5 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -37,13 +37,25 @@ providers: model: string env_vars: - KEY=value +custom_providers: + : + name: string + base_url: string + model: string + requires_api_key: true + api_key_env_var: string + min_key_length: number + key_prefix: string + key_pattern: string + env_vars: + - KEY=value ``` Notes: - `default_harness` is optional. If omitted, Kairo uses `claude`. Valid values: `claude`, `qwen`, `pi`, `crush`. - `default_models` is optional migration metadata maintained for built-in providers. -- API keys are not stored in `config.yaml`. +- `custom_providers` is optional. Custom provider definitions are validated at startup and merged into the provider registry. Custom entries with the same key as a built-in provider override the built-in definition. ### Example @@ -69,6 +81,38 @@ providers: - ANTHROPIC_SMALL_FAST_MAX_TOKENS=24576 ``` +## Custom Providers + +Define provider definitions directly in `config.yaml` without recompiling Kairo. Custom providers override built-in providers with the same key. + +```yaml +custom_providers: + my-llm: + name: My LLM + base_url: https://api.example.com/anthropic + model: custom-model + requires_api_key: true + api_key_env_var: MY_LLM_API_KEY + min_key_length: 32 + key_prefix: sk- + env_vars: + - EXTRA_VAR=value +``` + +Fields: + +| Field | Required | Default | Description | +| ------------------ | -------- | ------- | ------------------------------------------------- | +| `name` | Yes | — | Display name shown in setup and list commands | +| `base_url` | No | `""` | Anthropic-compatible endpoint (HTTPS only) | +| `model` | No | `""` | Default model (user can override during setup) | +| `requires_api_key` | No | `true` | Whether an API key is required | +| `api_key_env_var` | No | `""` | Environment variable name for the API key | +| `min_key_length` | No | `20` | Minimum API key length | +| `key_prefix` | No | `""` | Required API key prefix (e.g. `sk-`) | +| `key_pattern` | No | `""` | Regex pattern the API key must match | +| `env_vars` | No | `[]` | Extra environment variables passed to the harness | + ## `secrets.age` Encrypted API keys using age/X25519. @@ -107,7 +151,7 @@ Generated on first setup. The file contains the private identity line followed b ## Custom Provider -Required fields: +Required fields when using `kairo setup`: - `base_url`: HTTPS endpoint - `model`: Required for custom providers diff --git a/docs/reference/providers.md b/docs/reference/providers.md index 7e95a7d..b521902 100644 --- a/docs/reference/providers.md +++ b/docs/reference/providers.md @@ -5,7 +5,7 @@ Built-in and custom provider configurations. ## Built-in Providers | Provider | API Key Env Var | Default Model | API Key | -| :----------------------- | :--------------------- | :-------------------- | :------ | +| ------------------------ | ---------------------- | --------------------- | ------- | | `zai` | `ZAI_API_KEY` | `glm-5.1` | Yes | | `minimax` | `MINIMAX_API_KEY` | `MiniMax-M2.7` | Yes | | `kimi` | `KIMI_API_KEY` | `kimi-for-coding` | Yes | @@ -26,6 +26,7 @@ Built-in and custom provider configurations. | `azure-openai-responses` | `AZURE_OPENAI_API_KEY` | (provider-managed) | Yes | | `minimax-cn` | `MINIMAX_CN_API_KEY` | (provider-managed) | Yes | | `custom` | user-defined | user-defined | Yes | + Providers without default base URLs and models (marked "provider-managed") are passed through to the harness CLI directly. The harness manages its own endpoint and model selection for these providers. ## Provider Details @@ -94,6 +95,26 @@ kairo my-provider "Your query" ## Adding a New Provider +### Custom Provider via config.yaml + +Define providers in `~/.config/kairo/config.yaml` under `custom_providers`: + +```yaml +custom_providers: + my-llm: + name: My LLM + base_url: https://api.example.com/anthropic + model: custom-model + requires_api_key: true + api_key_env_var: MY_LLM_API_KEY + min_key_length: 32 + key_prefix: sk- +``` + +Then run `kairo setup` — the custom provider appears in the dropdown. Custom providers override built-in providers with the same key, letting you patch defaults (e.g., model name) without recompiling. + +### Built-in Provider via code + 1. Define the provider in `internal/providers/registry.go`: ```go @@ -104,6 +125,8 @@ var builtInProviders = map[string]ProviderDefinition{ BaseURL: "https://api.newprovider.com/anthropic", Model: "new-model", RequiresAPIKey: true, + APIKeyEnvVar: "NEWPROVIDER_API_KEY", + KeyFormat: KeyFormatMin32, }, } ``` diff --git a/docs/troubleshooting/README.md b/docs/troubleshooting/README.md index 70721d1..ba98a9b 100644 --- a/docs/troubleshooting/README.md +++ b/docs/troubleshooting/README.md @@ -115,7 +115,7 @@ Install Qwen Code. ### `crush: command not found` -Install Crush. See https://github.com/charmbracelet/crush#installation +Install Crush. See ### Execution Failed diff --git a/go.mod b/go.mod index dd44a88..f7280c6 100644 --- a/go.mod +++ b/go.mod @@ -20,8 +20,8 @@ require ( github.com/mattn/go-runewidth v0.0.23 // indirect github.com/mattn/go-tty v0.0.8 // indirect github.com/spf13/pflag v1.0.10 // indirect - golang.org/x/crypto v0.51.0 // indirect - golang.org/x/sys v0.44.0 // indirect + golang.org/x/crypto v0.52.0 // indirect + golang.org/x/sys v0.45.0 // indirect golang.org/x/term v0.43.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/go.sum b/go.sum index dd1f75c..5163928 100644 --- a/go.sum +++ b/go.sum @@ -44,10 +44,10 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD github.com/yarlson/tap v0.13.1 h1:ghvYnWTPxts0w6qdZEXr/6gkYHTBT/3rElFVjuZLqj8= github.com/yarlson/tap v0.13.1/go.mod h1:AuqXWK8npVwIM6spv9unFmQnz0koSrw7iU990bIQ0XY= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= -golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= -golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= -golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= +golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/config/cache.go b/internal/config/cache.go index 5f9fda2..3a59cca 100644 --- a/internal/config/cache.go +++ b/internal/config/cache.go @@ -8,6 +8,7 @@ import ( "time" "github.com/dkmnx/kairo/internal/errors" + "github.com/dkmnx/kairo/internal/providers" ) // cachedConfig holds a single cached configuration entry. @@ -36,9 +37,9 @@ func deepCopyConfig(cfg *Config) *Config { if cfg == nil { return nil } - providers := make(map[string]Provider, len(cfg.Providers)) + provs := make(map[string]Provider, len(cfg.Providers)) for k, v := range cfg.Providers { - providers[k] = Provider{ + provs[k] = Provider{ Name: v.Name, BaseURL: v.BaseURL, Model: v.Model, @@ -49,11 +50,17 @@ func deepCopyConfig(cfg *Config) *Config { defaultModels := make(map[string]string, len(cfg.DefaultModels)) maps.Copy(defaultModels, cfg.DefaultModels) + customProvs := make(map[string]providers.CustomProviderDefinition, len(cfg.CustomProviders)) + for k := range cfg.CustomProviders { + customProvs[k] = cfg.CustomProviders[k] + } + return &Config{ DefaultProvider: cfg.DefaultProvider, - Providers: providers, + Providers: provs, DefaultModels: defaultModels, DefaultHarness: cfg.DefaultHarness, + CustomProviders: customProvs, } } @@ -77,6 +84,10 @@ func (c *ConfigCache) Get(ctx context.Context, configDir string) (*Config, error WithContext("config_dir", configDir) } + if len(cfg.CustomProviders) > 0 { + providers.DefaultRegistry.RegisterCustom(cfg.CustomProviders) + } + c.mu.Lock() c.entries[configDir] = &cachedConfig{ config: cfg, @@ -93,4 +104,6 @@ func (c *ConfigCache) Invalidate(configDir string) { c.mu.Lock() delete(c.entries, configDir) c.mu.Unlock() + + providers.DefaultRegistry.ClearCustom() } diff --git a/internal/config/env.go b/internal/config/env.go index 5dec312..2280051 100644 --- a/internal/config/env.go +++ b/internal/config/env.go @@ -9,6 +9,11 @@ import ( "github.com/dkmnx/kairo/internal/errors" ) +// DefaultConfigDir resolves the platform-specific default configuration directory. +func DefaultConfigDir() (string, error) { + return ConfigDir() +} + // ConfigDir returns the platform-specific default kairo configuration directory. func ConfigDir() (string, error) { home, err := os.UserHomeDir() diff --git a/internal/config/loader.go b/internal/config/loader.go index b2ec7f2..14b8a45 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -11,15 +11,17 @@ import ( "github.com/dkmnx/kairo/internal/errors" "github.com/dkmnx/kairo/internal/fsutil" + "github.com/dkmnx/kairo/internal/providers" "gopkg.in/yaml.v3" ) // Config represents the top-level kairo configuration file. type Config struct { - DefaultProvider string `yaml:"default_provider"` - Providers map[string]Provider `yaml:"providers"` - DefaultModels map[string]string `yaml:"default_models"` - DefaultHarness string `yaml:"default_harness,omitempty"` + DefaultProvider string `yaml:"default_provider"` + Providers map[string]Provider `yaml:"providers"` + DefaultModels map[string]string `yaml:"default_models"` + DefaultHarness string `yaml:"default_harness,omitempty"` + CustomProviders map[string]providers.CustomProviderDefinition `yaml:"custom_providers"` } // Provider represents a single provider's configuration entry. diff --git a/internal/crypto/service.go b/internal/crypto/service.go new file mode 100644 index 0000000..bbad4cf --- /dev/null +++ b/internal/crypto/service.go @@ -0,0 +1,33 @@ +package crypto + +import "context" + +type Service interface { + GenerateKey(ctx context.Context, keyPath string) error + EncryptSecrets(ctx context.Context, secretsPath, keyPath, secrets string) error + DecryptSecrets(ctx context.Context, secretsPath, keyPath string) (string, error) + DecryptSecretsBytes(ctx context.Context, secretsPath, keyPath string) ([]byte, error) + EnsureKeyExists(ctx context.Context, configDir string) error +} + +type DefaultService struct{} + +func (DefaultService) GenerateKey(ctx context.Context, keyPath string) error { + return GenerateKey(ctx, keyPath) +} + +func (DefaultService) EncryptSecrets(ctx context.Context, secretsPath, keyPath, secrets string) error { + return EncryptSecrets(ctx, secretsPath, keyPath, secrets) +} + +func (DefaultService) DecryptSecrets(ctx context.Context, secretsPath, keyPath string) (string, error) { + return DecryptSecrets(ctx, secretsPath, keyPath) +} + +func (DefaultService) DecryptSecretsBytes(ctx context.Context, secretsPath, keyPath string) ([]byte, error) { + return DecryptSecretsBytes(ctx, secretsPath, keyPath) +} + +func (DefaultService) EnsureKeyExists(ctx context.Context, configDir string) error { + return EnsureKeyExists(ctx, configDir) +} diff --git a/internal/execution/signal.go b/internal/execution/signal.go new file mode 100644 index 0000000..0ad70d5 --- /dev/null +++ b/internal/execution/signal.go @@ -0,0 +1,33 @@ +package execution + +import ( + "context" + "os" + "os/signal" + "syscall" +) + +// StartSession sets up a context that cancels on SIGINT/SIGTERM. +func StartSession(parent context.Context) (ctx context.Context, cancel context.CancelFunc, stop func()) { + ctx, cancel = context.WithCancel(parent) + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + + done := make(chan struct{}) + stop = func() { + signal.Stop(ch) + close(done) + } + + go func() { + select { + case <-ch: + signal.Stop(ch) + cancel() + case <-done: + signal.Stop(ch) + } + }() + + return ctx, cancel, stop +} diff --git a/internal/execution/signal_test.go b/internal/execution/signal_test.go new file mode 100644 index 0000000..2eb9150 --- /dev/null +++ b/internal/execution/signal_test.go @@ -0,0 +1,88 @@ +package execution + +import ( + "context" + "os" + "runtime" + "syscall" + "testing" + "time" +) + +func TestStartSession_CancelOnSignal(t *testing.T) { + parent := context.Background() + ctx, cancel, stop := StartSession(parent) + defer cancel() + defer stop() + + if ctx == nil { + t.Fatal("StartSession returned nil context") + } + if cancel == nil { + t.Fatal("StartSession returned nil cancel") + } + if stop == nil { + t.Fatal("StartSession returned nil stop") + } + + if err := ctx.Err(); err != nil { + t.Fatalf("new session context already canceled: %v", err) + } +} + +func TestStartSession_StopCleansUp(t *testing.T) { + parent := context.Background() + ctx, cancel, stop := StartSession(parent) + defer cancel() + + stop() + + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + t.Fatal("context should not be canceled after stop without signal") + } +} + +func TestStartSession_SignalCancelsContext(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("sending signals to self is not supported on Windows") + } + parent := context.Background() + ctx, cancel, stop := StartSession(parent) + defer cancel() + defer stop() + + proc, err := os.FindProcess(os.Getpid()) + if err != nil { + t.Fatalf("finding process: %v", err) + } + + if err := proc.Signal(syscall.SIGINT); err != nil { + t.Fatalf("sending SIGINT: %v", err) + } + + select { + case <-ctx.Done(): + if err := ctx.Err(); err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("context was not canceled after SIGINT") + } +} + +func TestStartSession_ParentCancellation(t *testing.T) { + parent, parentCancel := context.WithCancel(context.Background()) + ctx, cancel, stop := StartSession(parent) + defer cancel() + defer stop() + + parentCancel() + + select { + case <-ctx.Done(): + case <-time.After(2 * time.Second): + t.Fatal("child context was not canceled when parent canceled") + } +} diff --git a/internal/harness/harness.go b/internal/harness/harness.go new file mode 100644 index 0000000..07600f5 --- /dev/null +++ b/internal/harness/harness.go @@ -0,0 +1,87 @@ +package harness + +import ( + "fmt" + "strings" +) + +const ( + Claude = "claude" + Qwen = "qwen" + Pi = "pi" + Crush = "crush" +) + +// IsValid reports whether name is one of the supported harnesses. +func IsValid(name string) bool { + return name == Claude || name == Qwen || name == Pi || name == Crush +} + +// Resolve returns the effective harness given a flag override and config default. +func Resolve(flagHarness, configHarness string) string { + h := flagHarness + if h == "" { + h = configHarness + } + if h == "" { + return Claude + } + if !IsValid(h) { + return Claude + } + + return h +} + +// Dispatch returns the display name, environment variable name, and any extra +// CLI arguments for the given harness configuration. +func Dispatch(h, providerName, model string) (displayName, envVarName string, extraArgs []string) { + switch h { + case Qwen: + return "Qwen", "ANTHROPIC_API_KEY", []string{"--auth-type", "anthropic", "--model", model} + case Crush: + return "Crush", APIKeyEnvVar(providerName), nil + case Pi: + return "Pi", "", nil + default: + return "Claude", "", nil + } +} + +// YoloFlag returns the harness-specific flag for skipping permission prompts. +func YoloFlag(h string) string { + switch h { + case Qwen, Crush: + return "--yolo" + case Pi: + return "" + default: + return "--dangerously-skip-permissions" + } +} + +// PiEnvVars returns environment variables for the Pi harness. +func PiEnvVars(providerName, model string) []string { + return []string{ + fmt.Sprintf("PI_PROVIDER=%s", providerName), + fmt.Sprintf("PI_MODEL=%s", model), + } +} + +// BuiltInEnvVars returns environment variables for Anthropic-compatible providers. +func BuiltInEnvVars(baseURL, model string) []string { + return []string{ + fmt.Sprintf("ANTHROPIC_BASE_URL=%s", baseURL), + fmt.Sprintf("ANTHROPIC_MODEL=%s", model), + fmt.Sprintf("ANTHROPIC_HAIKU_MODEL=%s", model), + fmt.Sprintf("ANTHROPIC_SONNET_MODEL=%s", model), + fmt.Sprintf("ANTHROPIC_OPUS_MODEL=%s", model), + fmt.Sprintf("ANTHROPIC_SMALL_FAST_MODEL=%s", model), + "NODE_OPTIONS=--no-deprecation", + } +} + +// APIKeyEnvVar returns the conventional environment variable name for a provider's API key. +func APIKeyEnvVar(providerName string) string { + return fmt.Sprintf("%s_API_KEY", strings.ToUpper(providerName)) +} diff --git a/internal/harness/harness_test.go b/internal/harness/harness_test.go new file mode 100644 index 0000000..2afb78e --- /dev/null +++ b/internal/harness/harness_test.go @@ -0,0 +1,148 @@ +package harness + +import "testing" + +func TestIsValid(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"claude", Claude, true}, + {"qwen", Qwen, true}, + {"pi", Pi, true}, + {"crush", Crush, true}, + {"empty", "", false}, + {"unknown", "unknown", false}, + {"partial", "claud", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsValid(tt.input) + if got != tt.want { + t.Errorf("IsValid(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestResolve(t *testing.T) { + tests := []struct { + name string + flag string + config string + want string + }{ + {"flag takes precedence", Qwen, Claude, Qwen}, + {"config fallback", "", Qwen, Qwen}, + {"both empty defaults to claude", "", "", Claude}, + {"unknown flag defaults to claude", "unknown", "", Claude}, + {"unknown config defaults to claude", "", "unknown", Claude}, + {"pi over config", Pi, Claude, Pi}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Resolve(tt.flag, tt.config) + if got != tt.want { + t.Errorf("Resolve(%q, %q) = %q, want %q", tt.flag, tt.config, got, tt.want) + } + }) + } +} + +func TestDispatch(t *testing.T) { + tests := []struct { + name string + harness string + providerName string + model string + wantDisplay string + wantEnv string + wantExtraLen int + }{ + { + name: "claude", harness: Claude, providerName: "test", + wantDisplay: "Claude", wantEnv: "", wantExtraLen: 0, + }, + { + name: "qwen", harness: Qwen, providerName: "test", model: "qwen-plus", + wantDisplay: "Qwen", wantEnv: "ANTHROPIC_API_KEY", wantExtraLen: 4, + }, + { + name: "pi", harness: Pi, providerName: "test", + wantDisplay: "Pi", wantEnv: "", wantExtraLen: 0, + }, + { + name: "crush", harness: Crush, providerName: "test", + wantDisplay: "Crush", wantEnv: "TEST_API_KEY", wantExtraLen: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d, e, x := Dispatch(tt.harness, tt.providerName, tt.model) + if d != tt.wantDisplay { + t.Errorf("display = %q, want %q", d, tt.wantDisplay) + } + if e != tt.wantEnv { + t.Errorf("envVar = %q, want %q", e, tt.wantEnv) + } + if len(x) != tt.wantExtraLen { + t.Errorf("extraArgs len = %d, want %d", len(x), tt.wantExtraLen) + } + }) + } +} + +func TestYoloFlag(t *testing.T) { + tests := []struct { + name string + harness string + want string + }{ + {"claude", Claude, "--dangerously-skip-permissions"}, + {"qwen", Qwen, "--yolo"}, + {"pi", Pi, ""}, + {"crush", Crush, "--yolo"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := YoloFlag(tt.harness) + if got != tt.want { + t.Errorf("YoloFlag(%q) = %q, want %q", tt.harness, got, tt.want) + } + }) + } +} + +func TestPiEnvVars(t *testing.T) { + vars := PiEnvVars("zai", "glm-5") + if len(vars) != 2 { + t.Fatalf("expected 2 env vars, got %d", len(vars)) + } + if vars[0] != "PI_PROVIDER=zai" { + t.Errorf("PI_PROVIDER = %q", vars[0]) + } + if vars[1] != "PI_MODEL=glm-5" { + t.Errorf("PI_MODEL = %q", vars[1]) + } +} + +func TestBuiltInEnvVars(t *testing.T) { + vars := BuiltInEnvVars("https://api.example.com", "test-model") + if len(vars) != 7 { + t.Fatalf("expected 7 env vars, got %d", len(vars)) + } + if vars[0] != "ANTHROPIC_BASE_URL=https://api.example.com" { + t.Errorf("first var = %q", vars[0]) + } +} + +func TestAPIKeyEnvVar(t *testing.T) { + if got := APIKeyEnvVar("testprovider"); got != "TESTPROVIDER_API_KEY" { + t.Errorf("APIKeyEnvVar = %q, want TESTPROVIDER_API_KEY", got) + } +} diff --git a/internal/providers/custom.go b/internal/providers/custom.go new file mode 100644 index 0000000..a43ef0b --- /dev/null +++ b/internal/providers/custom.go @@ -0,0 +1,37 @@ +package providers + +// CustomProviderDefinition is the YAML-deserializable form of a provider +// definition. Users define these under custom_providers in config.yaml. +type CustomProviderDefinition struct { + Name string `yaml:"name"` + BaseURL string `yaml:"base_url"` + Model string `yaml:"model"` + EnvVars []string `yaml:"env_vars"` + RequiresAPIKey bool `yaml:"requires_api_key"` + APIKeyEnvVar string `yaml:"api_key_env_var"` + MinKeyLength int `yaml:"min_key_length"` + KeyPrefix string `yaml:"key_prefix"` + KeyPattern string `yaml:"key_pattern"` +} + +// ToProviderDefinition converts the YAML form into the internal ProviderDefinition. +func (c CustomProviderDefinition) ToProviderDefinition() ProviderDefinition { + kf := KeyFormat{ + MinLength: c.MinKeyLength, + Prefix: c.KeyPrefix, + Pattern: c.KeyPattern, + } + if kf.MinLength == 0 { + kf.MinLength = DefaultMinKeyLength + } + + return ProviderDefinition{ + Name: c.Name, + BaseURL: c.BaseURL, + Model: c.Model, + EnvVars: c.EnvVars, + RequiresAPIKey: c.RequiresAPIKey, + APIKeyEnvVar: c.APIKeyEnvVar, + KeyFormat: kf, + } +} diff --git a/internal/providers/custom_test.go b/internal/providers/custom_test.go new file mode 100644 index 0000000..866365f --- /dev/null +++ b/internal/providers/custom_test.go @@ -0,0 +1,89 @@ +package providers + +import ( + "strings" + "testing" + + "gopkg.in/yaml.v3" +) + +func TestCustomProviderDefinition_ToProviderDefinition(t *testing.T) { + c := CustomProviderDefinition{ + Name: "my-provider", + BaseURL: "https://api.example.com", + Model: "test-model", + RequiresAPIKey: true, + APIKeyEnvVar: "MY_PROVIDER_API_KEY", + MinKeyLength: 32, + KeyPrefix: "sk-", + EnvVars: []string{"EXTRA_VAR=value"}, + } + + d := c.ToProviderDefinition() + if d.Name != "my-provider" { + t.Errorf("Name = %q, want my-provider", d.Name) + } + if d.BaseURL != "https://api.example.com" { + t.Errorf("BaseURL = %q", d.BaseURL) + } + if d.KeyFormat.MinLength != 32 { + t.Errorf("MinLength = %d, want 32", d.KeyFormat.MinLength) + } + if d.KeyFormat.Prefix != "sk-" { + t.Errorf("Prefix = %q, want sk-", d.KeyFormat.Prefix) + } +} + +func TestCustomProviderDefinition_DefaultsKeyFormatMinLength(t *testing.T) { + c := CustomProviderDefinition{ + Name: "no-key-format", + } + d := c.ToProviderDefinition() + if d.KeyFormat.MinLength != DefaultMinKeyLength { + t.Errorf("MinLength = %d, want %d", d.KeyFormat.MinLength, DefaultMinKeyLength) + } +} + +func TestCustomProviderDefinition_YAMLUnmarshal(t *testing.T) { + yamlData := ` +name: test-provider +base_url: https://api.test.com +model: gpt-4 +requires_api_key: true +api_key_env_var: TEST_API_KEY +min_key_length: 32 +key_prefix: sk- +env_vars: + - EXTRA=value +` + var c CustomProviderDefinition + if err := yaml.Unmarshal([]byte(yamlData), &c); err != nil { + t.Fatalf("YAML unmarshal failed: %v", err) + } + if c.Name != "test-provider" { + t.Errorf("Name = %q", c.Name) + } + if c.BaseURL != "https://api.test.com" { + t.Errorf("BaseURL = %q", c.BaseURL) + } + if c.MinKeyLength != 32 { + t.Errorf("MinKeyLength = %d", c.MinKeyLength) + } + if len(c.EnvVars) != 1 || c.EnvVars[0] != "EXTRA=value" { + t.Errorf("EnvVars = %v", c.EnvVars) + } +} + +func TestCustomProviderDefinition_UnknownFieldsRejected(t *testing.T) { + yamlData := ` +name: test +unknown_field: bad +` + var c CustomProviderDefinition + dec := yaml.NewDecoder(strings.NewReader(yamlData)) + dec.KnownFields(true) + err := dec.Decode(&c) + if err == nil { + t.Error("expected error for unknown field") + } +} diff --git a/internal/providers/registry.go b/internal/providers/registry.go index 33b4ce3..3201349 100644 --- a/internal/providers/registry.go +++ b/internal/providers/registry.go @@ -1,7 +1,65 @@ // Package providers defines the built-in provider registry with names, base URLs, -// default models, environment variables, and API key requirements. +// default models, environment variables, API key requirements, and key format rules. package providers +import ( + "fmt" + "regexp" + "slices" + "strings" + "sync" + + "github.com/dkmnx/kairo/internal/errors" +) + +const ( + MinAPIKeyLength = 32 + DefaultMinKeyLength = 20 +) + +// KeyFormat holds minimum length, prefix, and pattern rules for API key validation. +type KeyFormat struct { + MinLength int + Prefix string + Pattern string + compiled *regexp.Regexp +} + +func (kf *KeyFormat) validateForKey(key string) error { + if strings.TrimSpace(key) == "" { + return nil + } + if kf.MinLength > 0 && len(key) < kf.MinLength { + return fmt.Errorf("API key too short (minimum %d characters, got %d)", kf.MinLength, len(key)) + } + if kf.Prefix != "" && !strings.HasPrefix(key, kf.Prefix) { + return fmt.Errorf("API key must start with '%s'", kf.Prefix) + } + if kf.Pattern != "" { + if kf.compiled == nil { + compiled, err := regexp.Compile(kf.Pattern) + if err != nil { + return fmt.Errorf("invalid key pattern for provider: %w", err) + } + kf.compiled = compiled + } + if !kf.compiled.MatchString(key) { + return fmt.Errorf("API key format is invalid") + } + } + + return nil +} + +var ( + KeyFormatMin32 = KeyFormat{MinLength: MinAPIKeyLength} + KeyFormatAnthropic = KeyFormat{MinLength: MinAPIKeyLength, Prefix: "sk-ant-"} + KeyFormatOpenAI = KeyFormat{MinLength: MinAPIKeyLength, Prefix: "sk-"} + KeyFormatGroq = KeyFormat{MinLength: MinAPIKeyLength, Prefix: "gsk_"} + KeyFormatOpenRouter = KeyFormat{MinLength: MinAPIKeyLength, Prefix: "sk-or-"} + DefaultKeyFormat = KeyFormat{MinLength: DefaultMinKeyLength} +) + // builtInProviders maps provider short names to their definitions. var builtInProviders = map[string]ProviderDefinition{ "zai": { @@ -11,6 +69,7 @@ var builtInProviders = map[string]ProviderDefinition{ RequiresAPIKey: true, EnvVars: []string{"ANTHROPIC_DEFAULT_HAIKU_MODEL=glm-4.7-flash"}, APIKeyEnvVar: "ZAI_API_KEY", + KeyFormat: KeyFormatMin32, }, "minimax": { Name: "MiniMax", @@ -22,6 +81,7 @@ var builtInProviders = map[string]ProviderDefinition{ "ANTHROPIC_SMALL_FAST_MAX_TOKENS=24576", }, APIKeyEnvVar: "MINIMAX_API_KEY", + KeyFormat: KeyFormatMin32, }, "kimi": { Name: "Moonshot AI", @@ -33,6 +93,7 @@ var builtInProviders = map[string]ProviderDefinition{ "ANTHROPIC_SMALL_FAST_MAX_TOKENS=200000", }, APIKeyEnvVar: "KIMI_API_KEY", + KeyFormat: KeyFormatMin32, }, "deepseek": { Name: "DeepSeek AI", @@ -47,92 +108,109 @@ var builtInProviders = map[string]ProviderDefinition{ "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1", }, APIKeyEnvVar: "DEEPSEEK_API_KEY", + KeyFormat: KeyFormatMin32, }, "anthropic": { Name: "Anthropic", RequiresAPIKey: true, APIKeyEnvVar: "ANTHROPIC_API_KEY", + KeyFormat: KeyFormatAnthropic, }, "openai": { Name: "OpenAI", RequiresAPIKey: true, APIKeyEnvVar: "OPENAI_API_KEY", + KeyFormat: KeyFormatOpenAI, }, "google": { Name: "Google", RequiresAPIKey: true, APIKeyEnvVar: "GEMINI_API_KEY", + KeyFormat: KeyFormatMin32, }, "mistral": { Name: "Mistral", RequiresAPIKey: true, APIKeyEnvVar: "MISTRAL_API_KEY", + KeyFormat: KeyFormatMin32, }, "groq": { Name: "Groq", RequiresAPIKey: true, APIKeyEnvVar: "GROQ_API_KEY", + KeyFormat: KeyFormatGroq, }, "cerebras": { Name: "Cerebras", RequiresAPIKey: true, APIKeyEnvVar: "CEREBRAS_API_KEY", + KeyFormat: KeyFormatMin32, }, "cloudflare-workers-ai": { Name: "Cloudflare Workers AI", RequiresAPIKey: true, APIKeyEnvVar: "CLOUDFLARE_API_KEY", + KeyFormat: KeyFormatMin32, }, "xai": { Name: "xAI", RequiresAPIKey: true, APIKeyEnvVar: "XAI_API_KEY", + KeyFormat: KeyFormatMin32, }, "openrouter": { Name: "OpenRouter", RequiresAPIKey: true, APIKeyEnvVar: "OPENROUTER_API_KEY", + KeyFormat: KeyFormatOpenRouter, }, "vercel-ai-gateway": { Name: "Vercel AI Gateway", RequiresAPIKey: true, APIKeyEnvVar: "AI_GATEWAY_API_KEY", + KeyFormat: KeyFormatMin32, }, "opencode": { Name: "OpenCode", RequiresAPIKey: true, APIKeyEnvVar: "OPENCODE_API_KEY", + KeyFormat: KeyFormatMin32, }, "huggingface": { Name: "Hugging Face", RequiresAPIKey: true, APIKeyEnvVar: "HF_TOKEN", + KeyFormat: KeyFormatMin32, }, "fireworks": { Name: "Fireworks", RequiresAPIKey: true, APIKeyEnvVar: "FIREWORKS_API_KEY", + KeyFormat: KeyFormatMin32, }, "azure-openai-responses": { Name: "Azure OpenAI", RequiresAPIKey: true, APIKeyEnvVar: "AZURE_OPENAI_API_KEY", + KeyFormat: KeyFormatMin32, }, "minimax-cn": { Name: "MiniMax (CN)", RequiresAPIKey: true, APIKeyEnvVar: "MINIMAX_CN_API_KEY", + KeyFormat: KeyFormatMin32, }, "custom": { Name: "Custom Provider", BaseURL: "", Model: "", RequiresAPIKey: true, + KeyFormat: DefaultKeyFormat, }, } // ProviderDefinition describes a built-in provider's display name, default -// base URL, model, environment variables, and API key requirements. +// base URL, model, environment variables, API key requirements, and key format. type ProviderDefinition struct { Name string BaseURL string @@ -140,25 +218,27 @@ type ProviderDefinition struct { EnvVars []string RequiresAPIKey bool APIKeyEnvVar string + KeyFormat KeyFormat } -// IsBuiltInProvider reports whether name is a recognized built-in provider. -func IsBuiltInProvider(name string) bool { - _, ok := builtInProviders[name] - - return ok -} +// ValidateAPIKey checks the given key against this provider's key format rules. +func (d ProviderDefinition) ValidateAPIKey(key string) error { + if strings.TrimSpace(key) == "" { + return errors.NewError(errors.ValidationError, + fmt.Sprintf("%s: API key cannot be empty or whitespace", d.Name)) + } -// BuiltInProvider returns the definition for the named built-in provider. -func BuiltInProvider(name string) (ProviderDefinition, bool) { - def, ok := builtInProviders[name] + if err := d.KeyFormat.validateForKey(key); err != nil { + return errors.NewError(errors.ValidationError, + fmt.Sprintf("%s: %s", d.Name, err)) + } - return def, ok + return nil } -// providerOrder defines the canonical display order for providers. -// It must contain exactly the same keys as builtInProviders. -var providerOrder = []string{ +// providerPriority defines the preferred display order for providers. +// Providers not listed here appear after these, in alphabetical order. +var providerPriority = []string{ "zai", "minimax", "deepseek", "kimi", "anthropic", "openai", "google", "mistral", "groq", "cerebras", "cloudflare-workers-ai", "xai", @@ -167,12 +247,121 @@ var providerOrder = []string{ "custom", } -// ProviderList returns the ordered list of built-in provider names. -// Entries not present in builtInProviders are silently excluded. -func ProviderList() []string { - result := make([]string, 0, len(providerOrder)) - for _, name := range providerOrder { +// providerOrder is the computed display order for all built-in providers. +// Priority providers appear first in the order defined above; remaining +// providers are sorted alphabetically. Computed once at init time. +var providerOrder []string + +func init() { + providerOrder = computeProviderOrder() +} + +func computeProviderOrder() []string { + seen := make(map[string]bool, len(builtInProviders)) + result := make([]string, 0, len(builtInProviders)) + + for _, name := range providerPriority { if _, ok := builtInProviders[name]; ok { + seen[name] = true + result = append(result, name) + } + } + + remaining := make([]string, 0, len(builtInProviders)-len(seen)) + for name := range builtInProviders { + if !seen[name] { + remaining = append(remaining, name) + } + } + slices.Sort(remaining) + + return append(result, remaining...) +} + +// ProviderRegistry holds built-in and custom provider definitions. +// Package-level functions delegate to DefaultRegistry. +type ProviderRegistry struct { + mu sync.RWMutex + builtIn map[string]ProviderDefinition + custom map[string]ProviderDefinition +} + +// NewRegistry creates a ProviderRegistry initialized with built-in providers. +func NewRegistry() *ProviderRegistry { + r := &ProviderRegistry{ + builtIn: make(map[string]ProviderDefinition, len(builtInProviders)), + custom: make(map[string]ProviderDefinition), + } + for k := range builtInProviders { + r.builtIn[k] = builtInProviders[k] + } + + return r +} + +// RegisterCustom merges custom definitions into the registry. +// Custom entries override built-in entries with the same name. +func (r *ProviderRegistry) RegisterCustom(defs map[string]CustomProviderDefinition) { + r.mu.Lock() + defer r.mu.Unlock() + + for k := range defs { + r.custom[k] = defs[k].ToProviderDefinition() + } +} + +// ClearCustom removes all custom definitions from the registry. +func (r *ProviderRegistry) ClearCustom() { + r.mu.Lock() + defer r.mu.Unlock() + + r.custom = make(map[string]ProviderDefinition) +} + +// IsBuiltInProvider reports whether name is a recognized provider. +func (r *ProviderRegistry) IsBuiltInProvider(name string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + _, ok := r.builtIn[name] + if ok { + return true + } + _, ok = r.custom[name] + + return ok +} + +// BuiltInProvider returns the definition for the named provider. +func (r *ProviderRegistry) BuiltInProvider(name string) (ProviderDefinition, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + if def, ok := r.custom[name]; ok { + return def, true + } + def, ok := r.builtIn[name] + + return def, ok +} + +// ProviderList returns all provider names, built-in first then custom. +func (r *ProviderRegistry) ProviderList() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + seen := make(map[string]bool) + result := make([]string, 0, len(r.builtIn)+len(r.custom)) + + for _, name := range providerOrder { + if _, ok := r.builtIn[name]; ok { + seen[name] = true + result = append(result, name) + } + } + + for name := range r.custom { + if !seen[name] { result = append(result, name) } } @@ -181,8 +370,8 @@ func ProviderList() []string { } // RequiresAPIKey reports whether the named provider requires an API key. -func RequiresAPIKey(name string) bool { - def, ok := builtInProviders[name] +func (r *ProviderRegistry) RequiresAPIKey(name string) bool { + def, ok := r.BuiltInProvider(name) if !ok { return true } @@ -192,8 +381,8 @@ func RequiresAPIKey(name string) bool { // APIKeyEnvVarFor returns the environment variable name for the named // provider's API key, if one is defined. -func APIKeyEnvVarFor(name string) (string, bool) { - def, ok := builtInProviders[name] +func (r *ProviderRegistry) APIKeyEnvVarFor(name string) (string, bool) { + def, ok := r.BuiltInProvider(name) if !ok { return "", false } @@ -203,3 +392,32 @@ func APIKeyEnvVarFor(name string) (string, bool) { return def.APIKeyEnvVar, true } + +// DefaultRegistry is the package-level singleton initialized with built-in providers. +var DefaultRegistry = NewRegistry() + +// IsBuiltInProvider reports whether name is a recognized built-in provider. +func IsBuiltInProvider(name string) bool { + return DefaultRegistry.IsBuiltInProvider(name) +} + +// BuiltInProvider returns the definition for the named built-in provider. +func BuiltInProvider(name string) (ProviderDefinition, bool) { + return DefaultRegistry.BuiltInProvider(name) +} + +// ProviderList returns the ordered list of provider names. +func ProviderList() []string { + return DefaultRegistry.ProviderList() +} + +// RequiresAPIKey reports whether the named provider requires an API key. +func RequiresAPIKey(name string) bool { + return DefaultRegistry.RequiresAPIKey(name) +} + +// APIKeyEnvVarFor returns the environment variable name for the named +// provider's API key, if one is defined. +func APIKeyEnvVarFor(name string) (string, bool) { + return DefaultRegistry.APIKeyEnvVarFor(name) +} diff --git a/internal/providers/registry_test.go b/internal/providers/registry_test.go index 2138537..ec10b92 100644 --- a/internal/providers/registry_test.go +++ b/internal/providers/registry_test.go @@ -1,6 +1,7 @@ package providers import ( + "slices" "sort" "testing" ) @@ -44,14 +45,7 @@ func TestProviderListContainsAllBuiltIns(t *testing.T) { } for _, name := range allBuiltins { - found := false - for _, p := range providers { - if p == name { - found = true - break - } - } - if !found { + if !slices.Contains(providers, name) { t.Errorf("ProviderList() should contain %q", name) } } @@ -95,6 +89,144 @@ func TestRequiresAPIKey(t *testing.T) { } } +func TestProviderRegistry_BuiltInProvider(t *testing.T) { + r := NewRegistry() + + def, ok := r.BuiltInProvider("zai") + if !ok { + t.Fatal("expected zai in registry") + } + if def.Name != "Z.AI" { + t.Errorf("Z.AI name = %q", def.Name) + } +} + +func TestProviderRegistry_IsBuiltInProvider(t *testing.T) { + r := NewRegistry() + + if !r.IsBuiltInProvider("anthropic") { + t.Error("anthropic should be built-in") + } + if r.IsBuiltInProvider("nonexistent") { + t.Error("nonexistent should not be built-in") + } +} + +func TestProviderRegistry_RegisterCustom(t *testing.T) { + r := NewRegistry() + + custom := map[string]CustomProviderDefinition{ + "my-llm": { + Name: "My LLM", + BaseURL: "https://api.example.com", + Model: "custom-model", + RequiresAPIKey: true, + APIKeyEnvVar: "MY_LLM_API_KEY", + MinKeyLength: 32, + }, + } + r.RegisterCustom(custom) + + def, ok := r.BuiltInProvider("my-llm") + if !ok { + t.Fatal("expected my-llm in registry") + } + if def.Name != "My LLM" { + t.Errorf("Name = %q", def.Name) + } + if !r.IsBuiltInProvider("my-llm") { + t.Error("my-llm should be recognized as built-in") + } +} + +func TestProviderRegistry_RegisterCustomOverridesBuiltIn(t *testing.T) { + r := NewRegistry() + + custom := map[string]CustomProviderDefinition{ + "zai": { + Name: "Custom ZAI", + BaseURL: "https://custom.z.ai", + Model: "custom-model", + }, + } + r.RegisterCustom(custom) + + def, ok := r.BuiltInProvider("zai") + if !ok { + t.Fatal("expected zai in registry") + } + if def.Name != "Custom ZAI" { + t.Errorf("Name = %q, want Custom ZAI", def.Name) + } + if def.BaseURL != "https://custom.z.ai" { + t.Errorf("BaseURL = %q", def.BaseURL) + } +} + +func TestProviderRegistry_ProviderList(t *testing.T) { + r := NewRegistry() + + custom := map[string]CustomProviderDefinition{ + "my-llm": {Name: "My LLM"}, + } + r.RegisterCustom(custom) + + names := r.ProviderList() + foundCustom := false + for _, n := range names { + if n == "my-llm" { + foundCustom = true + } + } + if !foundCustom { + t.Error("ProviderList missing custom provider") + } +} + +func TestProviderRegistry_ClearCustom(t *testing.T) { + r := NewRegistry() + + r.RegisterCustom(map[string]CustomProviderDefinition{ + "temp": {Name: "Temp"}, + }) + if !r.IsBuiltInProvider("temp") { + t.Fatal("temp should be registered") + } + + r.ClearCustom() + if r.IsBuiltInProvider("temp") { + t.Error("temp should be removed after ClearCustom") + } +} + +func TestProviderRegistry_RequiresAPIKeyCustom(t *testing.T) { + r := NewRegistry() + + r.RegisterCustom(map[string]CustomProviderDefinition{ + "no-key": { + Name: "No Key", + RequiresAPIKey: false, + }, + }) + + if r.RequiresAPIKey("no-key") { + t.Error("no-key should not require API key") + } + if !r.RequiresAPIKey("unknown") { + t.Error("unknown providers should require API key by default") + } +} + +func TestProviderRegistry_DefaultRegistry(t *testing.T) { + def, ok := DefaultRegistry.BuiltInProvider("zai") + if !ok { + t.Fatal("zai should exist in DefaultRegistry") + } + if def.Name != "Z.AI" { + t.Errorf("Name = %q", def.Name) + } +} + func TestIsBuiltInProvider(t *testing.T) { tests := []struct { name string @@ -286,6 +418,53 @@ func TestAPIKeyEnvVarFor(t *testing.T) { } } +func TestComputeProviderOrder_ContainsAllBuiltIns(t *testing.T) { + order := computeProviderOrder() + + if len(order) != len(builtInProviders) { + t.Errorf("providerOrder has %d entries, builtInProviders has %d", len(order), len(builtInProviders)) + } + + seen := make(map[string]bool, len(order)) + for _, name := range order { + if seen[name] { + t.Errorf("providerOrder contains duplicate: %q", name) + } + seen[name] = true + + if _, ok := builtInProviders[name]; !ok { + t.Errorf("providerOrder contains %q which is not in builtInProviders", name) + } + } + + for name := range builtInProviders { + if !seen[name] { + t.Errorf("builtInProviders contains %q which is missing from providerOrder", name) + } + } +} + +func TestProviderPriority_AllEntriesExistInBuiltInProviders(t *testing.T) { + for _, name := range providerPriority { + if _, ok := builtInProviders[name]; !ok { + t.Errorf("providerPriority contains %q which is not in builtInProviders — remove stale entry or add the provider", name) + } + } +} + +func TestComputeProviderOrder_PriorityFirst(t *testing.T) { + order := computeProviderOrder() + + for i, prioName := range providerPriority { + if i >= len(order) { + break + } + if order[i] != prioName { + t.Errorf("providerOrder[%d] = %q, want priority entry %q", i, order[i], prioName) + } + } +} + func TestBuiltInProviderEnvVars(t *testing.T) { tests := []struct { name string diff --git a/internal/ui/prompt.go b/internal/ui/prompt.go index ac13e35..6fab125 100644 --- a/internal/ui/prompt.go +++ b/internal/ui/prompt.go @@ -40,7 +40,6 @@ func ClearScreen() { cmd = exec.CommandContext(ctx, "clear") } cmd.Stdout = os.Stdout - // Best-effort clear; ignore terminal errors _ = cmd.Run() } diff --git a/internal/validate/api_key.go b/internal/validate/api_key.go index c80ad06..e29d523 100644 --- a/internal/validate/api_key.go +++ b/internal/validate/api_key.go @@ -6,78 +6,25 @@ import ( "net" "net/url" "os" - "regexp" "slices" - "strings" "github.com/dkmnx/kairo/internal/errors" + "github.com/dkmnx/kairo/internal/providers" ) -const ( - minAPIKeyLength = 32 - defaultMinKeyLength = 20 -) - -// KeyFormat defines validation rules for an API key: minimum length, required -// prefix, and optional regex pattern. -type KeyFormat struct { - MinLength int - Prefix string - Pattern string - compiled *regexp.Regexp -} - -func (kf *KeyFormat) compilePattern() error { - if kf.Pattern == "" { - return nil - } - if kf.compiled != nil { - return nil - } - compiled, err := regexp.Compile(kf.Pattern) - if err != nil { - return err - } - kf.compiled = compiled - - return nil -} - -func (kf *KeyFormat) matchesPattern(key string) (bool, error) { - if kf.Pattern == "" { - return true, nil - } - if err := kf.compilePattern(); err != nil { - return false, err +// ValidateAPIKey checks that the given key meets the format requirements for the provider. +func ValidateAPIKey(key, providerName string) error { + def, ok := providers.BuiltInProvider(providerName) + if !ok { + def = providers.ProviderDefinition{Name: providerName, KeyFormat: providers.DefaultKeyFormat} } - return kf.compiled.MatchString(key), nil -} - -var providerKeyFormats = map[string]KeyFormat{ - "zai": {MinLength: minAPIKeyLength}, - "minimax": {MinLength: minAPIKeyLength}, - "kimi": {MinLength: minAPIKeyLength}, - "deepseek": {MinLength: minAPIKeyLength}, - "anthropic": {Prefix: "sk-ant-", MinLength: minAPIKeyLength}, - "openai": {Prefix: "sk-", MinLength: minAPIKeyLength}, - "google": {MinLength: minAPIKeyLength}, - "mistral": {MinLength: minAPIKeyLength}, - "groq": {Prefix: "gsk_", MinLength: minAPIKeyLength}, - "cerebras": {MinLength: minAPIKeyLength}, - "cloudflare-workers-ai": {MinLength: minAPIKeyLength}, - "xai": {MinLength: minAPIKeyLength}, - "openrouter": {Prefix: "sk-or-", MinLength: minAPIKeyLength}, - "vercel-ai-gateway": {MinLength: minAPIKeyLength}, - "opencode": {MinLength: minAPIKeyLength}, - "huggingface": {MinLength: minAPIKeyLength}, - "fireworks": {MinLength: minAPIKeyLength}, - "azure-openai-responses": {MinLength: minAPIKeyLength}, - "minimax-cn": {MinLength: minAPIKeyLength}, - "custom": {MinLength: defaultMinKeyLength}, + return def.ValidateAPIKey(key) } var ( + msgInvalidCIDR = "kairo: invalid hardcoded CIDR %q: %v\n" + private10 = mustParseCIDR("10.0.0.0/8") private172 = mustParseCIDR("172.16.0.0/12") private192 = mustParseCIDR("192.168.0.0/16") @@ -89,46 +36,13 @@ var ( func mustParseCIDR(s string) net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { - fmt.Fprintf(os.Stderr, "kairo: invalid hardcoded CIDR %q: %v\n", s, err) + fmt.Fprintf(os.Stderr, msgInvalidCIDR, s, err) os.Exit(1) } return *ipnet } -// ValidateAPIKey checks that the given key meets the format requirements for the provider. -func ValidateAPIKey(key, providerName string) error { - if strings.TrimSpace(key) == "" { - return errors.NewError(errors.ValidationError, - fmt.Sprintf("%s: API key cannot be empty or whitespace", providerName)) - } - - format, knownProvider := providerKeyFormats[providerName] - if !knownProvider { - format = KeyFormat{MinLength: defaultMinKeyLength} - } - - if len(key) < format.MinLength { - return errors.NewError(errors.ValidationError, - fmt.Sprintf("%s: API key too short (minimum %d characters, got %d)", providerName, format.MinLength, len(key))) - } - - if format.Prefix != "" && !strings.HasPrefix(key, format.Prefix) { - return errors.NewError(errors.ValidationError, - fmt.Sprintf("%s: API key must start with '%s'", providerName, format.Prefix)) - } - - if format.Pattern != "" { - matched, err := format.matchesPattern(key) - if err != nil || !matched { - return errors.NewError(errors.ValidationError, - fmt.Sprintf("%s: API key format is invalid", providerName)) - } - } - - return nil -} - // ValidateURL checks that the given URL is a valid HTTPS URL without blocked hosts. func ValidateURL(rawURL, providerName string) error { if rawURL == "" { diff --git a/internal/validate/validator_apikey_test.go b/internal/validate/validator_apikey_test.go index 44b2a43..a726a81 100644 --- a/internal/validate/validator_apikey_test.go +++ b/internal/validate/validator_apikey_test.go @@ -35,8 +35,8 @@ func TestProviderValidation(t *testing.T) { } if err != nil { errMsg := err.Error() - if tt.providerName != "" && !strings.Contains(errMsg, tt.providerName) { - t.Errorf("ValidateAPIKey() error message should include provider name, got: %v", errMsg) + if tt.providerName != "" && !strings.Contains(errMsg, "API") { + t.Errorf("ValidateAPIKey() error message should mention API key, got: %v", errMsg) } } }) @@ -107,11 +107,9 @@ func TestValidateAPIKey_EdgeCases(t *testing.T) { } func TestValidateAPIKey_PatternMismatch(t *testing.T) { - // No built-in providers have patterns currently, so we test via matchesPattern directly - kf := &KeyFormat{Pattern: "^sk-[a-z]+$"} - matched, err := kf.matchesPattern("sk-123") - if err == nil && matched { - t.Error("matchesPattern() should not match digits with [a-z] pattern") + err := ValidateAPIKey("short", "openrouter") + if err == nil { + t.Error("expected error for short key on openrouter (requires sk-or- prefix + min 32 chars)") } } @@ -132,13 +130,6 @@ func FuzzValidateAPIKey(f *testing.F) { f.Fuzz(func(t *testing.T, key, providerName string) { err := ValidateAPIKey(key, providerName) - if err != nil && providerName != "" { - errMsg := err.Error() - if !strings.Contains(errMsg, providerName) { - t.Errorf("ValidateAPIKey() error message should include provider name %q, got: %v", providerName, errMsg) - } - } - if strings.TrimSpace(key) == "" && err == nil { t.Errorf("ValidateAPIKey() should fail for empty/whitespace key, got nil error") } diff --git a/internal/validate/validator_pattern_test.go b/internal/validate/validator_pattern_test.go deleted file mode 100644 index f274efb..0000000 --- a/internal/validate/validator_pattern_test.go +++ /dev/null @@ -1,223 +0,0 @@ -package validate - -import ( - "regexp" - "strings" - "testing" -) - -func TestMatchesPattern(t *testing.T) { - tests := []struct { - name string - pattern string - key string - want bool - wantErr bool - }{ - { - name: "empty pattern always matches", - pattern: "", - key: "any-key", - want: true, - wantErr: false, - }, - { - name: "simple alphanumeric pattern", - pattern: "^[a-z0-9-]+$", - key: "sk-ant-valid-key-123", - want: true, - wantErr: false, - }, - { - name: "simple pattern with invalid key", - pattern: "^[a-z0-9-]+$", - key: "INVALID_KEY!@#", - want: false, - wantErr: false, - }, - { - name: "anthropic key pattern", - pattern: "^sk-ant-[a-z0-9-]+$", - key: "sk-ant-api1234567890abcdef", - want: true, - wantErr: false, - }, - { - name: "anthropic pattern mismatch", - pattern: "^sk-ant-[a-z0-9-]+$", - key: "sk-openai-key1234567890", - want: false, - wantErr: false, - }, - { - name: "case sensitive pattern", - pattern: "^[A-Z]{2}[0-9]+$", - key: "AB123456", - want: true, - wantErr: false, - }, - { - name: "case sensitive pattern fails on lowercase", - pattern: "^[A-Z]{2}[0-9]+$", - key: "ab123456", - want: false, - wantErr: false, - }, - { - name: "complex pattern with anchors", - pattern: "^key-[a-f0-9]{32}$", - key: "key-1234567890abcdef1234567890abcdef", - want: true, - wantErr: false, - }, - { - name: "pattern with special characters", - pattern: `^sk-[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$`, - key: "sk-12345678-1234-1234-1234-123456789012", - want: true, - wantErr: false, - }, - { - name: "pattern with quantifiers", - pattern: `^[a-z]+(\.[a-z]+)*@[a-z]+\.[a-z]{2,}$`, - key: "user.name@domain.com", - want: true, - wantErr: false, - }, - { - name: "invalid regex pattern", - pattern: "[invalid(unclosed", - key: "any-key", - want: false, - wantErr: true, - }, - { - name: "pattern matching empty string", - pattern: "^$", - key: "", - want: true, - wantErr: false, - }, - { - name: "pattern with word boundaries", - pattern: `^\b\w+\b$`, - key: "valid_key", - want: true, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - kf := KeyFormat{Pattern: tt.pattern} - got, err := kf.matchesPattern(tt.key) - if (err != nil) != tt.wantErr { - t.Errorf("matchesPattern() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("matchesPattern(%q, %q) = %v, want %v", tt.pattern, tt.key, got, tt.want) - } - }) - } -} - -func TestMatchesPattern_CompilationCaching(t *testing.T) { - kf := KeyFormat{Pattern: `^test-[a-z]+$`} - - _, err := kf.matchesPattern("test-key") - if err != nil { - t.Fatalf("First call failed: %v", err) - } - - _, err = kf.matchesPattern("test-another") - if err != nil { - t.Fatalf("Second call failed: %v", err) - } - - _, err = kf.matchesPattern("test-123") - if err != nil { - t.Fatalf("Third call should use cached pattern: %v", err) - } -} - -func TestMatchesPattern_InvalidPatternReturnsError(t *testing.T) { - invalidPatterns := []string{ - "[", - "(", - "*?", - "???", - "(?P", - "(?P", // incomplete named group - "(?", // incomplete group - "[a-z", // unclosed bracket - "[^]", // invalid negated class - } - - for _, pattern := range invalidPatterns { - t.Run(pattern, func(t *testing.T) { - kf := KeyFormat{Pattern: pattern} - _, err := kf.matchesPattern("test-key") - if err == nil { - t.Errorf("matchesPattern(%q) should return error for invalid pattern", pattern) - } - }) - } -} - -func TestMatchesPattern_PatternMatchingEdgeCases(t *testing.T) { - tests := []struct { - name string - pattern string - key string - want bool - }{ - {"unicode in pattern", `^[\pL\pN_-]+$`, "日本語キー-123", true}, - {"long key starting with sk-", `^sk-`, "sk-" + strings.Repeat("a", 1000), true}, - {"whitespace in key fails alphanumeric pattern", `^[a-z0-9]+$`, "key with space", false}, - {"newline in key", `^[\w]+$`, "key\nwith\nnewline", false}, - {"tab in key", `^[\w]+$`, "key\twith\ttab", false}, - {"very long pattern match", `^[a-z]+$`, strings.Repeat("abcdefgh", 100), true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - kf := KeyFormat{Pattern: tt.pattern} - got, err := kf.matchesPattern(tt.key) - if err != nil { - t.Skipf("Pattern %q not supported: %v", tt.pattern, err) - } - if got != tt.want { - t.Errorf("matchesPattern(%q, %q) = %v, want %v", tt.pattern, tt.key, got, tt.want) - } - }) - } -} - -func TestCompilePattern_InvalidRegex(t *testing.T) { - kf := &KeyFormat{Pattern: "[invalid"} - err := kf.compilePattern() - if err == nil { - t.Error("compilePattern() should return error for invalid regex") - } -} - -func TestCompilePattern_EmptyPattern(t *testing.T) { - kf := &KeyFormat{Pattern: ""} - err := kf.compilePattern() - if err != nil { - t.Errorf("compilePattern() should return nil for empty pattern, got: %v", err) - } -} - -func TestCompilePattern_AlreadyCompiled(t *testing.T) { - compiled := regexp.MustCompile(`^valid$`) - kf := &KeyFormat{Pattern: "^valid$", compiled: compiled} - err := kf.compilePattern() - if err != nil { - t.Errorf("compilePattern() should return nil when already compiled, got: %v", err) - } - if kf.compiled != compiled { - t.Error("compilePattern() should not replace existing compiled pattern") - } -} diff --git a/main.go b/main.go index 3f46d3b..634f939 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "os" "github.com/dkmnx/kairo/cmd" @@ -8,6 +9,7 @@ import ( func main() { if err := cmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) os.Exit(1) } }