diff --git a/README.md b/README.md index 859009b..78a10eb 100644 --- a/README.md +++ b/README.md @@ -262,12 +262,6 @@ Lists secrets in a KV mount under a specific path in Vault. - `mount`: The mount path of the secret engine - `path`: (Optional) The path to list secrets from (defaults to root) -#### delete_secret -Delete secrets (or keys) in a KV mount under a specific path in Vault. -- `mount`: The mount path of the secret engine -- `path`: The path to the secret to delete -- `key`: (Optional) The key name to delete from the entire secret (defaults to deleting the entire secret) - #### write_secret Writes a secret to a KV mount in Vault. - `mount`: The mount path of the secret engine diff --git a/pkg/tools/kv/delete_secret.go b/pkg/tools/kv/delete_secret.go deleted file mode 100644 index 2f1c173..0000000 --- a/pkg/tools/kv/delete_secret.go +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright IBM Corp. 2025 -// SPDX-License-Identifier: MPL-2.0 - -package kv - -import ( - "context" - "fmt" - "github.com/hashicorp/vault-mcp-server/pkg/client" - "github.com/hashicorp/vault-mcp-server/pkg/utils" - "strings" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - log "github.com/sirupsen/logrus" -) - -// DeleteSecret creates a tool for deleting secrets from a Vault KV mount -func DeleteSecret(logger *log.Logger) server.ServerTool { - return server.ServerTool{ - Tool: mcp.NewTool("delete_secret", - mcp.WithToolAnnotation( - mcp.ToolAnnotation{ - DestructiveHint: utils.ToBoolPtr(true), - IdempotentHint: utils.ToBoolPtr(false), - }, - ), - mcp.WithDescription("Delete a secret from a KV mount in Vault using the specified path and mount. If you specify a key, only that key will be deleted. If no key is specified or you delete the last key, the entire secret will be deleted."), - mcp.WithString("mount", - mcp.Required(), - mcp.Description("The mount path of the secret engine. For example, if you want to delete to 'secrets/application/credentials', this should be 'secrets' without the trailing slash."), - ), - mcp.WithString("path", - mcp.Required(), - mcp.Description("The full path to delete the secret to without the mount prefix. For example, if you want to delete to 'secrets/application/credentials', this should be 'application/credentials'."), - ), - mcp.WithString("key", - mcp.DefaultString(""), - mcp.Description("A optional key in the secret to delete. If not specified, all keys in the the secret will be deleted."), - ), - ), - Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return deleteSecretHandler(ctx, req, logger) - }, - } -} - -func deleteSecretHandler(ctx context.Context, req mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { - logger.Debug("Handling delete_secret request") - - // Extract parameters - args, ok := req.Params.Arguments.(map[string]interface{}) - if !ok { - return mcp.NewToolResultError("Missing or invalid arguments format"), nil - } - - mount, err := utils.ExtractMountPath(args) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - path, ok := args["path"].(string) - if !ok || path == "" { - return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil - } - - // Can be empty to delete the entire secret - key, ok := args["key"].(string) - if !ok { - return mcp.NewToolResultError("Missing or invalid 'key' parameter"), nil - } - - logger.WithFields(log.Fields{ - "mount": mount, - "path": path, - "key": key, - }).Debug("Deleting secret") - - // Get Vault client from context - vault, err := client.GetVaultClientFromContext(ctx, logger) - if err != nil { - logger.WithError(err).Error("Failed to get Vault client") - return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil - } - - mounts, err := vault.Sys().ListMounts() - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list mounts: %v", err)), nil - } - - // Default to a v1 KV path - fullPath := fmt.Sprintf("%s/%s", mount, strings.TrimPrefix(path, "/")) - - isV2 := false - - // Check if the mount exists - if m, ok := mounts[mount+"/"]; ok { - // is it a KV v2 mount? - if m.Options["version"] == "2" { - isV2 = true - // Construct the full path for reading (KV v2 format) - fullPath = fmt.Sprintf("%s/data/%s", mount, strings.TrimPrefix(path, "/")) - } - } else { - return mcp.NewToolResultError(fmt.Sprintf("mount path '%s' does not exist. Use 'create_mount' with the type kv2 to create the mount.", mount)), nil - } - - // Read the current secret so we can update it with the new key-value pair (or replace it) - currentSecret, err := vault.Logical().Read(fullPath) - - if currentSecret == nil { - return mcp.NewToolResultError(fmt.Sprintf("no secret exists at path '%s' in mount '%s'", path, mount)), nil - } - - if isV2 { - // V2 Secrets can be marked deleted, we need to check the metadata deletion_time - if currentSecret.Data["data"] == nil { - metaData, ok := currentSecret.Data["metadata"].(map[string]interface{}) - if !ok { - return mcp.NewToolResultError("unexpected secret metadata format for v2 API"), nil - } - if metaData["deletion_time"] != nil { - return mcp.NewToolResultError(fmt.Sprintf("secret at path '%s' in mount '%s' is deleted and cannot be read.", path, mount)), nil - } - return mcp.NewToolResultError(fmt.Sprintf("no secret exists at path '%s' in mount '%s'", path, mount)), nil - } - } - - if key != "" { - - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to read secret: %v", err)), nil - } - - var secretData map[string]interface{} - var secretsMap map[string]interface{} - - if isV2 { - // V2 API structure: secret.Data["data"] contains the actual key-value pairs - data, ok := currentSecret.Data["data"].(map[string]interface{}) - if !ok { - return mcp.NewToolResultError("unexpected secret data format for v2 API"), nil - } - secretData = data - secretsMap = secretData["data"].(map[string]interface{}) - } else { - // V1 API structure: secret.Data directly contains the key-value pairs - secretData = currentSecret.Data - secretsMap = secretData - } - - // Delete the specified key from the secret - delete(secretsMap, key) - - // If we have no keys left, we should not write an empty secret - if len(secretsMap) != 0 { - // Write (or update) the secret - versionInfo, err := vault.Logical().Write(fullPath, secretData) - if err != nil { - logger.WithError(err).WithFields(log.Fields{ - "mount": mount, - "path": path, - "key": key, - "full_path": fullPath, - }).Error("Failed to write secret") - return mcp.NewToolResultError(fmt.Sprintf("Failed to write secret: %v", err)), nil - } - - successMsg := fmt.Sprintf("Successfully updated the secret, removing the key '%s' on path '%s' in mount '%s'", key, path, mount) - - // Write out the version information if available as the AI may decide on a different approach if a version is provided - if versionInfo != nil && versionInfo.Data != nil { - successMsg = fmt.Sprintf("Successfully wrote version %v of the secret to path '%s' in mount '%s' with key '%s'", versionInfo.Data["version"], path, mount, key) - } - - logger.WithFields(log.Fields{ - "mount": mount, - "path": path, - "key": key, - "v2": isV2, - }).Info("Successfully wrote secret") - - return mcp.NewToolResultText(successMsg), nil - } - - } - - // Delete the secret - _, err = vault.Logical().Delete(fullPath) - if err != nil { - logger.WithError(err).WithFields(log.Fields{ - "mount": mount, - "path": path, - "key": key, - "full_path": fullPath, - }).Error("Failed to delete secret") - return mcp.NewToolResultError(fmt.Sprintf("Failed to delete secret: %v", err)), nil - } - - successMsg := fmt.Sprintf("Successfully deleted secret at path '%s' in mount '%s'", path, mount) - - logger.WithFields(log.Fields{ - "mount": mount, - "path": path, - "key": key, - "v2": isV2, - }).Info("Successfully deleted secret") - - return mcp.NewToolResultText(successMsg), nil -} diff --git a/pkg/tools/kv/delete_secret_versions.go b/pkg/tools/kv/delete_secret_versions.go new file mode 100644 index 0000000..c3a1a16 --- /dev/null +++ b/pkg/tools/kv/delete_secret_versions.go @@ -0,0 +1,148 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "context" + "fmt" + "strings" + + "github.com/hashicorp/vault-mcp-server/pkg/client" + "github.com/hashicorp/vault-mcp-server/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// DeleteSecretVersions creates a tool for soft-deleting specific versions of a secret in a Vault KV v2 mount +func DeleteSecretVersions(logger *log.Logger) server.ServerTool { + return server.ServerTool{ + Tool: mcp.NewTool("delete_secret_versions", + mcp.WithToolAnnotation( + mcp.ToolAnnotation{ + DestructiveHint: utils.ToBoolPtr(true), + IdempotentHint: utils.ToBoolPtr(true), + }, + ), + mcp.WithDescription("Soft-delete specific versions of a secret in a KV v2 mount in Vault. The secret data is marked as deleted but can be recovered using undelete_secret. Only supported on KV v2 mounts."), + mcp.WithString("mount", + mcp.Required(), + mcp.Description("The mount path of the secret engine."), + ), + mcp.WithString("path", + mcp.Required(), + mcp.Description("The full path to the secret without the mount prefix."), + ), + mcp.WithArray("versions", + mcp.Description("An array of version numbers to soft-delete. For example: [1, 3, 5]. If not specified, the latest version is soft-deleted. Soft-deleted versions can be recovered with undelete_secret."), + ), + ), + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return deleteSecretVersionsHandler(ctx, req, logger) + }, + } +} + +func deleteSecretVersionsHandler(ctx context.Context, req mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { + logger.Debug("Handling delete_secret_versions request") + + // Extract parameters + args, ok := req.Params.Arguments.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("Missing or invalid arguments format"), nil + } + + mount, err := utils.ExtractMountPath(args) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + path, ok := args["path"].(string) + if !ok || path == "" { + return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil + } + + // Get Vault client from context + vault, err := client.GetVaultClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("Failed to get Vault client") + return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil + } + + isV2, err := getMountInfo(vault, mount) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if !isV2 { + return mcp.NewToolResultError("delete_secret_versions is only supported on KV v2 mounts"), nil + } + + versionsRaw, hasVersions := args["versions"].([]interface{}) + + if hasVersions && len(versionsRaw) > 0 { + // Convert float64 values from JSON to int + versions := make([]int, 0, len(versionsRaw)) + for _, v := range versionsRaw { + vFloat, ok := v.(float64) + if !ok { + return mcp.NewToolResultError("Invalid version number in 'versions' array — each element must be a number"), nil + } + versions = append(versions, int(vFloat)) + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + "versions": versions, + }).Debug("Soft-deleting secret versions") + + // Soft-delete specific versions at mount/delete/path + fullPath := fmt.Sprintf("%s/delete/%s", mount, strings.TrimPrefix(path, "/")) + _, err = vault.Logical().Write(fullPath, map[string]interface{}{ + "versions": versions, + }) + if err != nil { + logger.WithError(err).WithFields(log.Fields{ + "mount": mount, + "path": path, + "full_path": fullPath, + "versions": versions, + }).Error("Failed to soft-delete secret versions") + return mcp.NewToolResultError(fmt.Sprintf("Failed to soft-delete secret versions: %v", err)), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + "versions": versions, + }).Info("Successfully soft-deleted secret versions") + + return mcp.NewToolResultText(fmt.Sprintf("Successfully soft-deleted versions %v of secret at path '%s' in mount '%s'. Use undelete_secret to recover them.", versions, path, mount)), nil + } + + // No versions specified: DELETE on the data path soft-deletes the latest version + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + }).Debug("Soft-deleting latest secret version") + + dataPath := fmt.Sprintf("%s/data/%s", mount, strings.TrimPrefix(path, "/")) + _, err = vault.Logical().Delete(dataPath) + if err != nil { + logger.WithError(err).WithFields(log.Fields{ + "mount": mount, + "path": path, + "full_path": dataPath, + }).Error("Failed to soft-delete latest secret version") + return mcp.NewToolResultError(fmt.Sprintf("Failed to soft-delete latest secret version: %v", err)), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + }).Info("Successfully soft-deleted latest secret version") + + return mcp.NewToolResultText(fmt.Sprintf("Successfully soft-deleted the latest version of secret at path '%s' in mount '%s'. Use undelete_secret to recover it.", path, mount)), nil +} \ No newline at end of file diff --git a/pkg/tools/kv/delete_secret_versions_test.go b/pkg/tools/kv/delete_secret_versions_test.go new file mode 100644 index 0000000..fe7a328 --- /dev/null +++ b/pkg/tools/kv/delete_secret_versions_test.go @@ -0,0 +1,49 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func TestDeleteSecretVersions(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := DeleteSecretVersions(logger) + + assert.Equal(t, "delete_secret_versions", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "Soft-delete") + assert.NotNil(t, tool.Handler) + }) + + t.Run("annotations", func(t *testing.T) { + tool := DeleteSecretVersions(logger) + + assert.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.True(t, *tool.Tool.Annotations.DestructiveHint) + assert.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.True(t, *tool.Tool.Annotations.IdempotentHint) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := DeleteSecretVersions(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.Contains(t, tool.Tool.InputSchema.Required, "path") + assert.NotContains(t, tool.Tool.InputSchema.Required, "versions") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := DeleteSecretVersions(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["versions"]) + }) +} diff --git a/pkg/tools/kv/destroy_secret_versions.go b/pkg/tools/kv/destroy_secret_versions.go new file mode 100644 index 0000000..acf01d8 --- /dev/null +++ b/pkg/tools/kv/destroy_secret_versions.go @@ -0,0 +1,126 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "context" + "fmt" + "strings" + + "github.com/hashicorp/vault-mcp-server/pkg/client" + "github.com/hashicorp/vault-mcp-server/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// DestroySecretVersions creates a tool for permanently destroying secret versions in a Vault KV v2 mount +func DestroySecretVersions(logger *log.Logger) server.ServerTool { + return server.ServerTool{ + Tool: mcp.NewTool("destroy_secret_versions", + mcp.WithToolAnnotation( + mcp.ToolAnnotation{ + DestructiveHint: utils.ToBoolPtr(true), + IdempotentHint: utils.ToBoolPtr(true), + }, + ), + mcp.WithDescription("Permanently destroy specific versions of a secret in a KV v2 mount in Vault. Unlike delete, destroyed versions cannot be recovered. Only supported on KV v2 mounts."), + mcp.WithString("mount", + mcp.Required(), + mcp.Description("The mount path of the secret engine."), + ), + mcp.WithString("path", + mcp.Required(), + mcp.Description("The full path to the secret without the mount prefix."), + ), + mcp.WithArray("versions", + mcp.Required(), + mcp.Description("An array of version numbers to permanently destroy. For example: [1, 3, 5]."), + ), + ), + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return destroySecretVersionsHandler(ctx, req, logger) + }, + } +} + +func destroySecretVersionsHandler(ctx context.Context, req mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { + logger.Debug("Handling destroy_secret_versions request") + + // Extract parameters + args, ok := req.Params.Arguments.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("Missing or invalid arguments format"), nil + } + + mount, err := utils.ExtractMountPath(args) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + path, ok := args["path"].(string) + if !ok || path == "" { + return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil + } + + versionsRaw, ok := args["versions"].([]interface{}) + if !ok || len(versionsRaw) == 0 { + return mcp.NewToolResultError("Missing or invalid 'versions' parameter — must be a non-empty array of version numbers"), nil + } + + // Convert float64 values from JSON to int + versions := make([]int, 0, len(versionsRaw)) + for _, v := range versionsRaw { + vFloat, ok := v.(float64) + if !ok { + return mcp.NewToolResultError("Invalid version number in 'versions' array — each element must be a number"), nil + } + versions = append(versions, int(vFloat)) + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + "versions": versions, + }).Debug("Destroying secret versions") + + // Get Vault client from context + vault, err := client.GetVaultClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("Failed to get Vault client") + return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil + } + + isV2, err := getMountInfo(vault, mount) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if !isV2 { + return mcp.NewToolResultError("destroy_secret_versions is only supported on KV v2 mounts"), nil + } + + // Destroy at mount/destroy/path + fullPath := fmt.Sprintf("%s/destroy/%s", mount, strings.TrimPrefix(path, "/")) + _, err = vault.Logical().Write(fullPath, map[string]interface{}{ + "versions": versions, + }) + if err != nil { + logger.WithError(err).WithFields(log.Fields{ + "mount": mount, + "path": path, + "full_path": fullPath, + "versions": versions, + }).Error("Failed to destroy secret versions") + return mcp.NewToolResultError(fmt.Sprintf("Failed to destroy secret versions: %v", err)), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + "versions": versions, + }).Info("Successfully destroyed secret versions") + + return mcp.NewToolResultText(fmt.Sprintf("Successfully destroyed versions %v of secret at path '%s' in mount '%s'", versions, path, mount)), nil +} diff --git a/pkg/tools/kv/destroy_secret_versions_test.go b/pkg/tools/kv/destroy_secret_versions_test.go new file mode 100644 index 0000000..308a40f --- /dev/null +++ b/pkg/tools/kv/destroy_secret_versions_test.go @@ -0,0 +1,50 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" +) + +func TestDestroySecretVersions(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := DestroySecretVersions(logger) + + assert.Equal(t, "destroy_secret_versions", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "Permanently destroy") + assert.Contains(t, tool.Tool.Description, "KV v2") + assert.NotNil(t, tool.Handler) + }) + + t.Run("annotations", func(t *testing.T) { + tool := DestroySecretVersions(logger) + + assert.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.True(t, *tool.Tool.Annotations.DestructiveHint) + assert.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.True(t, *tool.Tool.Annotations.IdempotentHint) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := DestroySecretVersions(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.Contains(t, tool.Tool.InputSchema.Required, "path") + assert.Contains(t, tool.Tool.InputSchema.Required, "versions") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := DestroySecretVersions(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["versions"]) + }) +} diff --git a/pkg/tools/kv/helpers.go b/pkg/tools/kv/helpers.go new file mode 100644 index 0000000..1b54215 --- /dev/null +++ b/pkg/tools/kv/helpers.go @@ -0,0 +1,31 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "fmt" + + "github.com/hashicorp/vault/api" +) + +// getMountInfo checks whether a mount exists and if it's a KV v2 mount. +// Returns isV2=true for KV v2 mounts, isV2=false for KV v1 mounts. +// Returns an error if the mount does not exist. +func getMountInfo(vault *api.Client, mount string) (isV2 bool, err error) { + mounts, err := vault.Sys().ListMounts() + if err != nil { + return false, fmt.Errorf("failed to list mounts: %v", err) + } + + m, ok := mounts[mount+"/"] + if !ok { + return false, fmt.Errorf("mount path '%s' does not exist. Use 'create_mount' with the type kv2 to create the mount", mount) + } + + if m.Options["version"] == "2" { + return true, nil + } + + return false, nil +} diff --git a/pkg/tools/kv/list_secrets_test.go b/pkg/tools/kv/list_secrets_test.go new file mode 100644 index 0000000..42fc12c --- /dev/null +++ b/pkg/tools/kv/list_secrets_test.go @@ -0,0 +1,45 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" +) + +func TestListSecrets(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := ListSecrets(logger) + + assert.Equal(t, "list_secrets", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "List secrets") + assert.NotNil(t, tool.Handler) + }) + + t.Run("annotations", func(t *testing.T) { + tool := ListSecrets(logger) + + assert.NotNil(t, tool.Tool.Annotations.ReadOnlyHint) + assert.True(t, *tool.Tool.Annotations.ReadOnlyHint) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := ListSecrets(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.NotContains(t, tool.Tool.InputSchema.Required, "path") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := ListSecrets(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + }) +} diff --git a/pkg/tools/kv/patch_secret.go b/pkg/tools/kv/patch_secret.go new file mode 100644 index 0000000..d5c226e --- /dev/null +++ b/pkg/tools/kv/patch_secret.go @@ -0,0 +1,119 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "context" + "fmt" + "strings" + + "github.com/hashicorp/vault-mcp-server/pkg/client" + "github.com/hashicorp/vault-mcp-server/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// PatchSecret creates a tool for patching secrets in a Vault KV v2 mount +func PatchSecret(logger *log.Logger) server.ServerTool { + return server.ServerTool{ + Tool: mcp.NewTool("patch_secret", + mcp.WithToolAnnotation( + mcp.ToolAnnotation{ + DestructiveHint: utils.ToBoolPtr(true), + IdempotentHint: utils.ToBoolPtr(false), + }, + ), + mcp.WithDescription("Patch a secret in a KV v2 mount in Vault. Merges the provided data with the existing secret data without replacing unspecified keys. Uses HTTP PATCH with JSON merge patch semantics. Only supported on KV v2 mounts."), + mcp.WithString("mount", + mcp.Required(), + mcp.Description("The mount path of the secret engine."), + ), + mcp.WithString("path", + mcp.Required(), + mcp.Description("The full path to the secret without the mount prefix."), + ), + mcp.WithObject("data", + mcp.Required(), + mcp.Description("A key-value map of the secret data to merge. Only the specified keys will be updated; unspecified keys remain unchanged. For example: {\"password\": \"new_password\"} will update only the password key."), + ), + ), + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return patchSecretHandler(ctx, req, logger) + }, + } +} + +func patchSecretHandler(ctx context.Context, req mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { + logger.Debug("Handling patch_secret request") + + // Extract parameters + args, ok := req.Params.Arguments.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("Missing or invalid arguments format"), nil + } + + mount, err := utils.ExtractMountPath(args) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + path, ok := args["path"].(string) + if !ok || path == "" { + return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil + } + + data, ok := args["data"].(map[string]interface{}) + if !ok || data == nil { + return mcp.NewToolResultError("Missing or invalid 'data' parameter — must be a JSON object"), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + }).Debug("Patching secret") + + // Get Vault client from context + vault, err := client.GetVaultClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("Failed to get Vault client") + return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil + } + + isV2, err := getMountInfo(vault, mount) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if !isV2 { + return mcp.NewToolResultError("patch_secret is only supported on KV v2 mounts"), nil + } + + // Patch at mount/data/path using JSON merge patch + fullPath := fmt.Sprintf("%s/data/%s", mount, strings.TrimPrefix(path, "/")) + versionInfo, err := vault.Logical().JSONMergePatch(ctx, fullPath, map[string]interface{}{ + "data": data, + }) + if err != nil { + logger.WithError(err).WithFields(log.Fields{ + "mount": mount, + "path": path, + "full_path": fullPath, + }).Error("Failed to patch secret") + return mcp.NewToolResultError(fmt.Sprintf("Failed to patch secret: %v", err)), nil + } + + successMsg := fmt.Sprintf("Successfully patched secret at path '%s' in mount '%s'", path, mount) + + if versionInfo != nil && versionInfo.Data != nil { + successMsg = fmt.Sprintf("Successfully patched secret at path '%s' in mount '%s' (version %v)", path, mount, versionInfo.Data["version"]) + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + }).Info("Successfully patched secret") + + return mcp.NewToolResultText(successMsg), nil +} diff --git a/pkg/tools/kv/patch_secret_test.go b/pkg/tools/kv/patch_secret_test.go new file mode 100644 index 0000000..2906d62 --- /dev/null +++ b/pkg/tools/kv/patch_secret_test.go @@ -0,0 +1,50 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" +) + +func TestPatchSecret(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := PatchSecret(logger) + + assert.Equal(t, "patch_secret", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "Patch") + assert.Contains(t, tool.Tool.Description, "KV v2") + assert.NotNil(t, tool.Handler) + }) + + t.Run("annotations", func(t *testing.T) { + tool := PatchSecret(logger) + + assert.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.True(t, *tool.Tool.Annotations.DestructiveHint) + assert.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.False(t, *tool.Tool.Annotations.IdempotentHint) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := PatchSecret(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.Contains(t, tool.Tool.InputSchema.Required, "path") + assert.Contains(t, tool.Tool.InputSchema.Required, "data") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := PatchSecret(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["data"]) + }) +} diff --git a/pkg/tools/kv/read_secret.go b/pkg/tools/kv/read_secret.go index 160acaf..4662e97 100644 --- a/pkg/tools/kv/read_secret.go +++ b/pkg/tools/kv/read_secret.go @@ -7,11 +7,12 @@ import ( "context" "encoding/json" "fmt" - "github.com/hashicorp/vault-mcp-server/pkg/client" - "github.com/hashicorp/vault-mcp-server/pkg/utils" - + "strconv" "strings" + "github.com/hashicorp/vault-mcp-server/pkg/client" + "github.com/hashicorp/vault-mcp-server/pkg/utils" + "github.com/hashicorp/vault/api" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" log "github.com/sirupsen/logrus" @@ -30,6 +31,9 @@ func ReadSecret(logger *log.Logger) server.ServerTool { mcp.Required(), mcp.Description("The full path to read the secret to without the mount prefix. For example, if you want to read from 'secrets/application/credentials', this should be 'application/credentials'."), ), + mcp.WithNumber("version", + mcp.Description("The version of the secret to read. Only supported on KV v2 mounts. If not specified, the latest version is returned."), + ), ), Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return readSecretHandler(ctx, req, logger) @@ -56,6 +60,14 @@ func readSecretHandler(ctx context.Context, req mcp.CallToolRequest, logger *log return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil } + // Extract optional version parameter + var version int + hasVersion := false + if v, ok := args["version"].(float64); ok { + version = int(v) + hasVersion = true + } + logger.WithFields(log.Fields{ "mount": mount, "path": path, @@ -68,30 +80,31 @@ func readSecretHandler(ctx context.Context, req mcp.CallToolRequest, logger *log return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil } - mounts, err := vault.Sys().ListMounts() + isV2, err := getMountInfo(vault, mount) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list mounts: %v", err)), nil + return mcp.NewToolResultError(err.Error()), nil + } + + // Version parameter is only supported on KV v2 + if hasVersion && !isV2 { + return mcp.NewToolResultError("version parameter is only supported on KV v2 mounts"), nil } // Default to a v1 KV path fullPath := fmt.Sprintf("%s/%s", mount, strings.TrimPrefix(path, "/")) - - isV2 := false - - // Check if the mount exists - if m, ok := mounts[mount+"/"]; ok { - // is it a KV v2 mount? - if m.Options["version"] == "2" { - isV2 = true - // Construct the full path for reading (KV v2 format) - fullPath = fmt.Sprintf("%s/data/%s", mount, strings.TrimPrefix(path, "/")) - } - } else { - return mcp.NewToolResultError(fmt.Sprintf("mount path '%s' does not exist. Use 'create_mount' with the type kv2 to create the mount.", mount)), nil + if isV2 { + fullPath = fmt.Sprintf("%s/data/%s", mount, strings.TrimPrefix(path, "/")) } // Read the secret - secret, err := vault.Logical().Read(fullPath) + var secret *api.Secret + if hasVersion && isV2 { + secret, err = vault.Logical().ReadWithData(fullPath, map[string][]string{ + "version": {strconv.Itoa(version)}, + }) + } else { + secret, err = vault.Logical().Read(fullPath) + } if err != nil { logger.WithError(err).WithFields(log.Fields{ "mount": mount, diff --git a/pkg/tools/kv/read_secret_metadata.go b/pkg/tools/kv/read_secret_metadata.go new file mode 100644 index 0000000..2b44d7c --- /dev/null +++ b/pkg/tools/kv/read_secret_metadata.go @@ -0,0 +1,114 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/hashicorp/vault-mcp-server/pkg/client" + "github.com/hashicorp/vault-mcp-server/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// ReadSecretMetadata creates a tool for reading secret metadata from a Vault KV v2 mount +func ReadSecretMetadata(logger *log.Logger) server.ServerTool { + return server.ServerTool{ + Tool: mcp.NewTool("read_secret_metadata", + mcp.WithToolAnnotation( + mcp.ToolAnnotation{ + ReadOnlyHint: utils.ToBoolPtr(true), + IdempotentHint: utils.ToBoolPtr(true), + }, + ), + mcp.WithDescription("Read metadata for a secret from a KV v2 mount in Vault. Returns version history, custom metadata, and configuration like max_versions and cas_required. Only supported on KV v2 mounts."), + mcp.WithString("mount", + mcp.Required(), + mcp.Description("The mount path of the secret engine. For example, if you want to read metadata from 'secrets/application/credentials', this should be 'secrets' without the trailing slash."), + ), + mcp.WithString("path", + mcp.Required(), + mcp.Description("The full path to the secret without the mount prefix. For example, if you want to read metadata from 'secrets/application/credentials', this should be 'application/credentials'."), + ), + ), + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return readSecretMetadataHandler(ctx, req, logger) + }, + } +} + +func readSecretMetadataHandler(ctx context.Context, req mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { + logger.Debug("Handling read_secret_metadata request") + + // Extract parameters + args, ok := req.Params.Arguments.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("Missing or invalid arguments format"), nil + } + + mount, err := utils.ExtractMountPath(args) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + path, ok := args["path"].(string) + if !ok || path == "" { + return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + }).Debug("Reading secret metadata") + + // Get Vault client from context + vault, err := client.GetVaultClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("Failed to get Vault client") + return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil + } + + isV2, err := getMountInfo(vault, mount) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if !isV2 { + return mcp.NewToolResultError("read_secret_metadata is only supported on KV v2 mounts"), nil + } + + // Read metadata at mount/metadata/path + fullPath := fmt.Sprintf("%s/metadata/%s", mount, strings.TrimPrefix(path, "/")) + secret, err := vault.Logical().Read(fullPath) + if err != nil { + logger.WithError(err).WithFields(log.Fields{ + "mount": mount, + "path": path, + "full_path": fullPath, + }).Error("Failed to read secret metadata") + return mcp.NewToolResultError(fmt.Sprintf("Failed to read secret metadata: %v", err)), nil + } + + if secret == nil { + return mcp.NewToolResultError(fmt.Sprintf("No metadata found at path '%s' in mount '%s'", path, mount)), nil + } + + // Marshal to JSON + jsonData, err := json.Marshal(secret.Data) + if err != nil { + logger.WithError(err).Error("Failed to marshal metadata to JSON") + return mcp.NewToolResultError(fmt.Sprintf("Error marshaling JSON: %v", err)), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + }).Debug("Successfully read secret metadata") + + return mcp.NewToolResultText(string(jsonData)), nil +} diff --git a/pkg/tools/kv/read_secret_metadata_test.go b/pkg/tools/kv/read_secret_metadata_test.go new file mode 100644 index 0000000..3396a90 --- /dev/null +++ b/pkg/tools/kv/read_secret_metadata_test.go @@ -0,0 +1,48 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" +) + +func TestReadSecretMetadata(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := ReadSecretMetadata(logger) + + assert.Equal(t, "read_secret_metadata", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "metadata") + assert.Contains(t, tool.Tool.Description, "KV v2") + assert.NotNil(t, tool.Handler) + }) + + t.Run("annotations", func(t *testing.T) { + tool := ReadSecretMetadata(logger) + + assert.NotNil(t, tool.Tool.Annotations.ReadOnlyHint) + assert.True(t, *tool.Tool.Annotations.ReadOnlyHint) + assert.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.True(t, *tool.Tool.Annotations.IdempotentHint) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := ReadSecretMetadata(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.Contains(t, tool.Tool.InputSchema.Required, "path") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := ReadSecretMetadata(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + }) +} diff --git a/pkg/tools/kv/read_secret_test.go b/pkg/tools/kv/read_secret_test.go new file mode 100644 index 0000000..884ecf0 --- /dev/null +++ b/pkg/tools/kv/read_secret_test.go @@ -0,0 +1,40 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" +) + +func TestReadSecret(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := ReadSecret(logger) + + assert.Equal(t, "read_secret", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "Read a secret") + assert.NotNil(t, tool.Handler) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := ReadSecret(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.Contains(t, tool.Tool.InputSchema.Required, "path") + assert.NotContains(t, tool.Tool.InputSchema.Required, "version") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := ReadSecret(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["version"]) + }) +} diff --git a/pkg/tools/kv/undelete_secret.go b/pkg/tools/kv/undelete_secret.go new file mode 100644 index 0000000..2b28165 --- /dev/null +++ b/pkg/tools/kv/undelete_secret.go @@ -0,0 +1,126 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "context" + "fmt" + "strings" + + "github.com/hashicorp/vault-mcp-server/pkg/client" + "github.com/hashicorp/vault-mcp-server/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// UndeleteSecretVersions creates a tool for undeleting secret versions in a Vault KV v2 mount +func UndeleteSecretVersions(logger *log.Logger) server.ServerTool { + return server.ServerTool{ + Tool: mcp.NewTool("undelete_secret_versions", + mcp.WithToolAnnotation( + mcp.ToolAnnotation{ + DestructiveHint: utils.ToBoolPtr(false), + IdempotentHint: utils.ToBoolPtr(true), + }, + ), + mcp.WithDescription("Undelete (restore) previously soft-deleted versions of a secret in a KV v2 mount in Vault. Only supported on KV v2 mounts."), + mcp.WithString("mount", + mcp.Required(), + mcp.Description("The mount path of the secret engine."), + ), + mcp.WithString("path", + mcp.Required(), + mcp.Description("The full path to the secret without the mount prefix."), + ), + mcp.WithArray("versions", + mcp.Required(), + mcp.Description("An array of version numbers to undelete. For example: [1, 3, 5]."), + ), + ), + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return undeleteSecretHandler(ctx, req, logger) + }, + } +} + +func undeleteSecretHandler(ctx context.Context, req mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { + logger.Debug("Handling undelete_secret request") + + // Extract parameters + args, ok := req.Params.Arguments.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("Missing or invalid arguments format"), nil + } + + mount, err := utils.ExtractMountPath(args) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + path, ok := args["path"].(string) + if !ok || path == "" { + return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil + } + + versionsRaw, ok := args["versions"].([]interface{}) + if !ok || len(versionsRaw) == 0 { + return mcp.NewToolResultError("Missing or invalid 'versions' parameter — must be a non-empty array of version numbers"), nil + } + + // Convert float64 values from JSON to int + versions := make([]int, 0, len(versionsRaw)) + for _, v := range versionsRaw { + vFloat, ok := v.(float64) + if !ok { + return mcp.NewToolResultError("Invalid version number in 'versions' array — each element must be a number"), nil + } + versions = append(versions, int(vFloat)) + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + "versions": versions, + }).Debug("Undeleting secret versions") + + // Get Vault client from context + vault, err := client.GetVaultClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("Failed to get Vault client") + return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil + } + + isV2, err := getMountInfo(vault, mount) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if !isV2 { + return mcp.NewToolResultError("undelete_secret is only supported on KV v2 mounts"), nil + } + + // Undelete at mount/undelete/path + fullPath := fmt.Sprintf("%s/undelete/%s", mount, strings.TrimPrefix(path, "/")) + _, err = vault.Logical().Write(fullPath, map[string]interface{}{ + "versions": versions, + }) + if err != nil { + logger.WithError(err).WithFields(log.Fields{ + "mount": mount, + "path": path, + "full_path": fullPath, + "versions": versions, + }).Error("Failed to undelete secret versions") + return mcp.NewToolResultError(fmt.Sprintf("Failed to undelete secret versions: %v", err)), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + "versions": versions, + }).Info("Successfully undeleted secret versions") + + return mcp.NewToolResultText(fmt.Sprintf("Successfully undeleted versions %v of secret at path '%s' in mount '%s'", versions, path, mount)), nil +} diff --git a/pkg/tools/kv/undelete_secret_test.go b/pkg/tools/kv/undelete_secret_test.go new file mode 100644 index 0000000..037da46 --- /dev/null +++ b/pkg/tools/kv/undelete_secret_test.go @@ -0,0 +1,50 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" +) + +func TestUndeleteSecretVersions(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := UndeleteSecretVersions(logger) + + assert.Equal(t, "undelete_secret_versions", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "Undelete") + assert.Contains(t, tool.Tool.Description, "KV v2") + assert.NotNil(t, tool.Handler) + }) + + t.Run("annotations", func(t *testing.T) { + tool := UndeleteSecretVersions(logger) + + assert.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.False(t, *tool.Tool.Annotations.DestructiveHint) + assert.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.True(t, *tool.Tool.Annotations.IdempotentHint) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := UndeleteSecretVersions(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.Contains(t, tool.Tool.InputSchema.Required, "path") + assert.Contains(t, tool.Tool.InputSchema.Required, "versions") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := UndeleteSecretVersions(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["versions"]) + }) +} diff --git a/pkg/tools/kv/write_secret.go b/pkg/tools/kv/write_secret.go index d035044..ea4c79e 100644 --- a/pkg/tools/kv/write_secret.go +++ b/pkg/tools/kv/write_secret.go @@ -6,11 +6,10 @@ package kv import ( "context" "fmt" - "github.com/hashicorp/vault-mcp-server/pkg/client" - "github.com/hashicorp/vault-mcp-server/pkg/utils" - "strings" + "github.com/hashicorp/vault-mcp-server/pkg/client" + "github.com/hashicorp/vault-mcp-server/pkg/utils" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" log "github.com/sirupsen/logrus" @@ -26,7 +25,7 @@ func WriteSecret(logger *log.Logger) server.ServerTool { IdempotentHint: utils.ToBoolPtr(false), // We are not idempotent because writing a secret will always create a new version on the kv2 }, ), - mcp.WithDescription("Writes a secret value to a KV store in Vault using the specified path and mount. Supports both KV v1 and v2 mounts. If a KV v2 mount is detected, the currently stored version of the secret will be returned."), + mcp.WithDescription("Writes a secret to a KV store in Vault using the specified path and mount. The data parameter is a complete key-value map that will replace the secret at the given path. Supports both KV v1 and v2 mounts. If a KV v2 mount is detected, the currently stored version of the secret will be returned."), mcp.WithString("mount", mcp.Required(), mcp.Description("The mount path of the secret engine. For example, if you want to write to 'secrets/application/credentials', this should be 'secrets' without the trailing slash."), @@ -35,13 +34,9 @@ func WriteSecret(logger *log.Logger) server.ServerTool { mcp.Required(), mcp.Description("The full path to write the secret to without the mount prefix. For example, if you want to write to 'secrets/application/credentials', this should be 'application/credentials'."), ), - mcp.WithString("key", + mcp.WithObject("data", mcp.Required(), - mcp.Description("The key name for the secret. For example if you want to write mysecret=myvalue, this should be 'mysecret'"), - ), - mcp.WithString("value", - mcp.Required(), - mcp.Description("The value to store the given key. For example if you want to write mysecret=myvalue, this should be 'myvalue'"), + mcp.Description("A complete key-value map of the secret data to write. For example: {\"username\": \"admin\", \"password\": \"s3cret\"}. This will replace the entire secret at the given path."), ), ), Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -69,20 +64,14 @@ func writeSecretHandler(ctx context.Context, req mcp.CallToolRequest, logger *lo return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil } - key, ok := args["key"].(string) - if !ok || key == "" { - return mcp.NewToolResultError("Missing or invalid 'key' parameter"), nil - } - - value, ok := args["value"].(string) - if !ok || value == "" { - return mcp.NewToolResultError("Missing or invalid 'value' parameter"), nil + data, ok := args["data"].(map[string]interface{}) + if !ok || data == nil { + return mcp.NewToolResultError("Missing or invalid 'data' parameter — must be a JSON object"), nil } logger.WithFields(log.Fields{ "mount": mount, "path": path, - "key": key, }).Debug("Writing secret") // Get Vault client from context @@ -92,77 +81,48 @@ func writeSecretHandler(ctx context.Context, req mcp.CallToolRequest, logger *lo return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil } - mounts, err := vault.Sys().ListMounts() + isV2, err := getMountInfo(vault, mount) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list mounts: %v", err)), nil + return mcp.NewToolResultError(err.Error()), nil } - // Default to a v1 KV path + // Construct the full path fullPath := fmt.Sprintf("%s/%s", mount, strings.TrimPrefix(path, "/")) - - isV2 := false - - // Check if the mount exists - if m, ok := mounts[mount+"/"]; ok { - // is it a KV v2 mount? - if m.Options["version"] == "2" { - isV2 = true - // Construct the full path for reading (KV v2 format) - fullPath = fmt.Sprintf("%s/data/%s", mount, strings.TrimPrefix(path, "/")) - } - } else { - return mcp.NewToolResultError(fmt.Sprintf("mount path '%s' does not exist. Use 'create_mount' with the type kv2 to create the mount.", mount)), nil - } - - // Read the current secret so we can update it with the new key-value pair (or replace it) - currentSecret, err := vault.Logical().Read(fullPath) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to read secret: %v", err)), nil - } - - var secretData map[string]interface{} - - if currentSecret != nil { - secretData = currentSecret.Data + if isV2 { + fullPath = fmt.Sprintf("%s/data/%s", mount, strings.TrimPrefix(path, "/")) } + // Prepare the write data + var writeData map[string]interface{} if isV2 { - if secretData == nil { - secretData = map[string]interface{}{ - "data": make(map[string]interface{}), - } + writeData = map[string]interface{}{ + "data": data, } - secretData["data"].(map[string]interface{})[key] = value } else { - if secretData == nil { - secretData = map[string]interface{}{} - } - secretData[key] = value + writeData = data } - // Write (or update) the secret - versionInfo, err := vault.Logical().Write(fullPath, secretData) + // Write the secret + versionInfo, err := vault.Logical().Write(fullPath, writeData) if err != nil { logger.WithError(err).WithFields(log.Fields{ "mount": mount, "path": path, - "key": key, "full_path": fullPath, }).Error("Failed to write secret") return mcp.NewToolResultError(fmt.Sprintf("Failed to write secret: %v", err)), nil } - successMsg := fmt.Sprintf("Successfully updated the secret, adding or updating the key '%s' on path '%s' in mount '%s'", key, path, mount) + successMsg := fmt.Sprintf("Successfully wrote secret to path '%s' in mount '%s'", path, mount) - // Write out the version information if available as the AI may decide on a different approach if a version is provided + // Write out the version information if available if versionInfo != nil && versionInfo.Data != nil { - successMsg = fmt.Sprintf("Successfully wrote version %v of the secret to path '%s' in mount '%s' with key '%s'", versionInfo.Data["version"], path, mount, key) + successMsg = fmt.Sprintf("Successfully wrote version %v of the secret to path '%s' in mount '%s'", versionInfo.Data["version"], path, mount) } logger.WithFields(log.Fields{ "mount": mount, "path": path, - "key": key, "v2": isV2, }).Info("Successfully wrote secret") diff --git a/pkg/tools/kv/write_secret_metadata.go b/pkg/tools/kv/write_secret_metadata.go new file mode 100644 index 0000000..e4f8081 --- /dev/null +++ b/pkg/tools/kv/write_secret_metadata.go @@ -0,0 +1,134 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "context" + "fmt" + "strings" + + "github.com/hashicorp/vault-mcp-server/pkg/client" + "github.com/hashicorp/vault-mcp-server/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// WriteSecretMetadata creates a tool for writing secret metadata to a Vault KV v2 mount +func WriteSecretMetadata(logger *log.Logger) server.ServerTool { + return server.ServerTool{ + Tool: mcp.NewTool("write_secret_metadata", + mcp.WithToolAnnotation( + mcp.ToolAnnotation{ + DestructiveHint: utils.ToBoolPtr(true), + IdempotentHint: utils.ToBoolPtr(false), + }, + ), + mcp.WithDescription("Write metadata configuration for a secret in a KV v2 mount in Vault. Allows setting max_versions, cas_required, delete_version_after, and custom_metadata. Only supported on KV v2 mounts."), + mcp.WithString("mount", + mcp.Required(), + mcp.Description("The mount path of the secret engine."), + ), + mcp.WithString("path", + mcp.Required(), + mcp.Description("The full path to the secret without the mount prefix."), + ), + mcp.WithNumber("max_versions", + mcp.Description("The maximum number of versions to keep for the secret. If not set, the backend's configured max version is used."), + ), + mcp.WithBoolean("cas_required", + mcp.Description("If true, the backend will require the cas parameter to be set on every write."), + ), + mcp.WithString("delete_version_after", + mcp.Description("The duration after which a version is deleted. Accepts Go duration format (e.g. '3h25m', '72h')."), + ), + mcp.WithObject("custom_metadata", + mcp.Description("A map of arbitrary string key-value pairs to store as custom metadata for the secret."), + ), + ), + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return writeSecretMetadataHandler(ctx, req, logger) + }, + } +} + +func writeSecretMetadataHandler(ctx context.Context, req mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { + logger.Debug("Handling write_secret_metadata request") + + // Extract parameters + args, ok := req.Params.Arguments.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("Missing or invalid arguments format"), nil + } + + mount, err := utils.ExtractMountPath(args) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + path, ok := args["path"].(string) + if !ok || path == "" { + return mcp.NewToolResultError("Missing or invalid 'path' parameter"), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + }).Debug("Writing secret metadata") + + // Get Vault client from context + vault, err := client.GetVaultClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("Failed to get Vault client") + return mcp.NewToolResultError(fmt.Sprintf("Failed to get Vault client: %v", err)), nil + } + + isV2, err := getMountInfo(vault, mount) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if !isV2 { + return mcp.NewToolResultError("write_secret_metadata is only supported on KV v2 mounts"), nil + } + + // Build data map only with provided parameters + data := make(map[string]interface{}) + + if v, ok := args["max_versions"].(float64); ok { + data["max_versions"] = int(v) + } + if v, ok := args["cas_required"].(bool); ok { + data["cas_required"] = v + } + if v, ok := args["delete_version_after"].(string); ok { + data["delete_version_after"] = v + } + if v, ok := args["custom_metadata"].(map[string]interface{}); ok { + data["custom_metadata"] = v + } + + if len(data) == 0 { + return mcp.NewToolResultError("At least one metadata field must be provided (max_versions, cas_required, delete_version_after, or custom_metadata)"), nil + } + + // Write metadata at mount/metadata/path + fullPath := fmt.Sprintf("%s/metadata/%s", mount, strings.TrimPrefix(path, "/")) + _, err = vault.Logical().Write(fullPath, data) + if err != nil { + logger.WithError(err).WithFields(log.Fields{ + "mount": mount, + "path": path, + "full_path": fullPath, + }).Error("Failed to write secret metadata") + return mcp.NewToolResultError(fmt.Sprintf("Failed to write secret metadata: %v", err)), nil + } + + logger.WithFields(log.Fields{ + "mount": mount, + "path": path, + }).Info("Successfully wrote secret metadata") + + return mcp.NewToolResultText(fmt.Sprintf("Successfully wrote metadata for secret at path '%s' in mount '%s'", path, mount)), nil +} diff --git a/pkg/tools/kv/write_secret_metadata_test.go b/pkg/tools/kv/write_secret_metadata_test.go new file mode 100644 index 0000000..35cae8c --- /dev/null +++ b/pkg/tools/kv/write_secret_metadata_test.go @@ -0,0 +1,57 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" +) + +func TestWriteSecretMetadata(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := WriteSecretMetadata(logger) + + assert.Equal(t, "write_secret_metadata", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "metadata") + assert.Contains(t, tool.Tool.Description, "KV v2") + assert.NotNil(t, tool.Handler) + }) + + t.Run("annotations", func(t *testing.T) { + tool := WriteSecretMetadata(logger) + + assert.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.True(t, *tool.Tool.Annotations.DestructiveHint) + assert.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.False(t, *tool.Tool.Annotations.IdempotentHint) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := WriteSecretMetadata(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.Contains(t, tool.Tool.InputSchema.Required, "path") + // Optional params should not be required + assert.NotContains(t, tool.Tool.InputSchema.Required, "max_versions") + assert.NotContains(t, tool.Tool.InputSchema.Required, "cas_required") + assert.NotContains(t, tool.Tool.InputSchema.Required, "delete_version_after") + assert.NotContains(t, tool.Tool.InputSchema.Required, "custom_metadata") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := WriteSecretMetadata(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["max_versions"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["cas_required"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["delete_version_after"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["custom_metadata"]) + }) +} diff --git a/pkg/tools/kv/write_secret_test.go b/pkg/tools/kv/write_secret_test.go new file mode 100644 index 0000000..2de6768 --- /dev/null +++ b/pkg/tools/kv/write_secret_test.go @@ -0,0 +1,52 @@ +// Copyright IBM Corp. 2025 +// SPDX-License-Identifier: MPL-2.0 + +package kv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" +) + +func TestWriteSecret(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("tool creation", func(t *testing.T) { + tool := WriteSecret(logger) + + assert.Equal(t, "write_secret", tool.Tool.Name) + assert.Contains(t, tool.Tool.Description, "Writes a secret") + assert.NotNil(t, tool.Handler) + }) + + t.Run("annotations", func(t *testing.T) { + tool := WriteSecret(logger) + + assert.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.True(t, *tool.Tool.Annotations.DestructiveHint) + assert.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.False(t, *tool.Tool.Annotations.IdempotentHint) + }) + + t.Run("required parameters", func(t *testing.T) { + tool := WriteSecret(logger) + + assert.Contains(t, tool.Tool.InputSchema.Required, "mount") + assert.Contains(t, tool.Tool.InputSchema.Required, "path") + assert.Contains(t, tool.Tool.InputSchema.Required, "data") + }) + + t.Run("properties exist", func(t *testing.T) { + tool := WriteSecret(logger) + + assert.NotNil(t, tool.Tool.InputSchema.Properties["mount"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["path"]) + assert.NotNil(t, tool.Tool.InputSchema.Properties["data"]) + // key and value should no longer exist + assert.Nil(t, tool.Tool.InputSchema.Properties["key"]) + assert.Nil(t, tool.Tool.InputSchema.Properties["value"]) + }) +} diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index 0a92838..33c8c59 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -33,8 +33,23 @@ func InitTools(hcServer *server.MCPServer, logger *log.Logger) { writeSecretTool := kv.WriteSecret(logger) hcServer.AddTool(writeSecretTool.Tool, writeSecretTool.Handler) - deleteSecretTool := kv.DeleteSecret(logger) - hcServer.AddTool(deleteSecretTool.Tool, deleteSecretTool.Handler) + deleteSecretVersionsTool := kv.DeleteSecretVersions(logger) + hcServer.AddTool(deleteSecretVersionsTool.Tool, deleteSecretVersionsTool.Handler) + + readSecretMetadataTool := kv.ReadSecretMetadata(logger) + hcServer.AddTool(readSecretMetadataTool.Tool, readSecretMetadataTool.Handler) + + writeSecretMetadataTool := kv.WriteSecretMetadata(logger) + hcServer.AddTool(writeSecretMetadataTool.Tool, writeSecretMetadataTool.Handler) + + undeleteSecretVersionsTool := kv.UndeleteSecretVersions(logger) + hcServer.AddTool(undeleteSecretVersionsTool.Tool, undeleteSecretVersionsTool.Handler) + + destroySecretVersionsTool := kv.DestroySecretVersions(logger) + hcServer.AddTool(destroySecretVersionsTool.Tool, destroySecretVersionsTool.Handler) + + patchSecretTool := kv.PatchSecret(logger) + hcServer.AddTool(patchSecretTool.Tool, patchSecretTool.Handler) // Tools for PKI management enablePkiTool := pki.EnablePki(logger)