-
Notifications
You must be signed in to change notification settings - Fork 22
feat: add toolsets for filterting tools #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,7 @@ import ( | |||||||||||||||||||
|
|
||||||||||||||||||||
| "github.com/hashicorp/vault-mcp-server/pkg/client" | ||||||||||||||||||||
| "github.com/hashicorp/vault-mcp-server/pkg/tools" | ||||||||||||||||||||
| "github.com/hashicorp/vault-mcp-server/pkg/toolsets" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| "github.com/hashicorp/vault-mcp-server/version" | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -45,7 +46,7 @@ var ( | |||||||||||||||||||
| Use: "stdio", | ||||||||||||||||||||
| Short: "Start stdio server", | ||||||||||||||||||||
| Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, | ||||||||||||||||||||
| Run: func(_ *cobra.Command, _ []string) { | ||||||||||||||||||||
| Run: func(cmd *cobra.Command, _ []string) { | ||||||||||||||||||||
| logFile, err := rootCmd.PersistentFlags().GetString("log-file") | ||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||
| stdlog.Fatal("Failed to get log file:", err) | ||||||||||||||||||||
|
|
@@ -55,7 +56,9 @@ var ( | |||||||||||||||||||
| stdlog.Fatal("Failed to initialize logger:", err) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err := runStdioServer(logger); err != nil { | ||||||||||||||||||||
| enabledToolsets := getToolsetsFromCmd(cmd.Root(), logger) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err := runStdioServer(logger, enabledToolsets); err != nil { | ||||||||||||||||||||
| stdlog.Fatal("failed to run stdio server:", err) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| }, | ||||||||||||||||||||
|
|
@@ -91,7 +94,9 @@ You can specify the host, port, and endpoint path to customize where the server | |||||||||||||||||||
| stdlog.Fatal("Failed to get endpoint path:", err) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err := runHTTPServer(logger, host, port, endpointPath); err != nil { | ||||||||||||||||||||
| enabledToolsets := getToolsetsFromCmd(cmd.Root(), logger) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err := runHTTPServer(logger, host, port, endpointPath, enabledToolsets); err != nil { | ||||||||||||||||||||
| stdlog.Fatal("failed to run streamableHTTP server:", err) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| }, | ||||||||||||||||||||
|
|
@@ -110,12 +115,12 @@ You can specify the host, port, and endpoint path to customize where the server | |||||||||||||||||||
| } | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| func runHTTPServer(logger *log.Logger, host string, port string, endpointPath string) error { | ||||||||||||||||||||
| func runHTTPServer(logger *log.Logger, host string, port string, endpointPath string, enabledToolsets []string) error { | ||||||||||||||||||||
| ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) | ||||||||||||||||||||
| defer stop() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| hcServer := NewServer(version.Version, logger) | ||||||||||||||||||||
| tools.InitTools(hcServer, logger) | ||||||||||||||||||||
| tools.RegisterTools(hcServer, logger, enabledToolsets) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return httpServerInit(ctx, hcServer, logger, host, port, endpointPath) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
@@ -227,12 +232,12 @@ func httpServerInit(ctx context.Context, hcServer *server.MCPServer, logger *log | |||||||||||||||||||
| return nil | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| func runStdioServer(logger *log.Logger) error { | ||||||||||||||||||||
| func runStdioServer(logger *log.Logger, enabledToolsets []string) error { | ||||||||||||||||||||
| ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) | ||||||||||||||||||||
| defer stop() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| hcServer := NewServer(version.Version, logger) | ||||||||||||||||||||
| tools.InitTools(hcServer, logger) | ||||||||||||||||||||
| tools.RegisterTools(hcServer, logger, enabledToolsets) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return serverInit(ctx, hcServer, logger) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
@@ -283,14 +288,16 @@ func runDefaultCommand(cmd *cobra.Command, _ []string) { | |||||||||||||||||||
| stdlog.Fatal("Failed to initialize logger:", err) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err := runStdioServer(logger); err != nil { | ||||||||||||||||||||
| enabledToolsets := getToolsetsFromCmd(cmd, logger) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err := runStdioServer(logger, enabledToolsets); err != nil { | ||||||||||||||||||||
| stdlog.Fatal("failed to run stdio server:", err) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| func main() { | ||||||||||||||||||||
| // Check environment variables first - they override command line args | ||||||||||||||||||||
| if shouldUseHTTPMode() { | ||||||||||||||||||||
| if shouldUseStreamableHTTPMode() { | ||||||||||||||||||||
| port := getHTTPPort() | ||||||||||||||||||||
| host := getHTTPHost() | ||||||||||||||||||||
| endpointPath := getEndpointPath(nil) | ||||||||||||||||||||
|
|
@@ -301,8 +308,10 @@ func main() { | |||||||||||||||||||
| stdlog.Fatal("Failed to initialize logger:", err) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err := runHTTPServer(logger, host, port, endpointPath); err != nil { | ||||||||||||||||||||
| stdlog.Fatal("failed to run HTTP server:", err) | ||||||||||||||||||||
| enabledToolsets := getToolsetsFromCmd(rootCmd, logger) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err := runHTTPServer(logger, host, port, endpointPath, enabledToolsets); err != nil { | ||||||||||||||||||||
| stdlog.Fatal("failed to run StreamableHTTP server:", err) | ||||||||||||||||||||
| } | ||||||||||||||||||||
| return | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
@@ -314,8 +323,8 @@ func main() { | |||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // shouldUseHTTPMode checks if environment variables indicate HTTP mode | ||||||||||||||||||||
| func shouldUseHTTPMode() bool { | ||||||||||||||||||||
| // shouldUseStreamableHTTPMode checks if environment variables indicate HTTP mode | ||||||||||||||||||||
| func shouldUseStreamableHTTPMode() bool { | ||||||||||||||||||||
| transportMode := os.Getenv("TRANSPORT_MODE") | ||||||||||||||||||||
| return transportMode == "http" || transportMode == "streamable-http" || | ||||||||||||||||||||
| os.Getenv("TRANSPORT_PORT") != "" || | ||||||||||||||||||||
|
|
@@ -339,7 +348,74 @@ func getHTTPHost() string { | |||||||||||||||||||
| return DefaultBindAddress | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // Add function to get endpoint path from environment or flag | ||||||||||||||||||||
| // parseToolsets parses and validates the toolsets flag value | ||||||||||||||||||||
| func parseToolsets(toolsetsFlag string, logger *log.Logger) []string { | ||||||||||||||||||||
| rawToolsets := strings.Split(toolsetsFlag, ",") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| cleaned, invalid := toolsets.CleanToolsets(rawToolsets) | ||||||||||||||||||||
| if len(invalid) > 0 { | ||||||||||||||||||||
| logger.Warnf("Invalid toolsets ignored: %v", invalid) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| expanded := toolsets.ExpandDefaultToolset(cleaned) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| logger.Infof("Enabled toolsets: %v", expanded) | ||||||||||||||||||||
| return expanded | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // parseIndividualTools parses and validates the tools flag value | ||||||||||||||||||||
| func parseIndividualTools(toolsFlag string, logger *log.Logger) []string { | ||||||||||||||||||||
| rawTools := strings.Split(toolsFlag, ",") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| validTools, invalidTools := toolsets.ParseIndividualTools(rawTools) | ||||||||||||||||||||
| if len(invalidTools) > 0 { | ||||||||||||||||||||
| logger.Warnf("Invalid tool names ignored: %v", invalidTools) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if len(validTools) == 0 { | ||||||||||||||||||||
| logger.Warn("No valid tools specified, falling back to default toolsets") | ||||||||||||||||||||
| return parseToolsets("default", logger) | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // Use the public API to enable individual tools mode | ||||||||||||||||||||
| result := toolsets.EnableIndividualTools(validTools) | ||||||||||||||||||||
| logger.Infof("Enabled individual tools: %v", validTools) | ||||||||||||||||||||
| return result | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| func getToolsetsFromCmd(cmd *cobra.Command, logger *log.Logger) []string { | ||||||||||||||||||||
| // Check if --tools flag is set (individual tool mode) | ||||||||||||||||||||
| toolsFlag, err := cmd.Flags().GetString("tools") | ||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||
| // Try root persistent flags | ||||||||||||||||||||
| toolsFlag, err = cmd.Root().PersistentFlags().GetString("tools") | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if err == nil && toolsFlag != "" { | ||||||||||||||||||||
| // Ensure --toolsets is not also set | ||||||||||||||||||||
| toolsetsFlag, _ := cmd.Flags().GetString("toolsets") | ||||||||||||||||||||
| if toolsetsFlag == "" { | ||||||||||||||||||||
| toolsetsFlag, _ = cmd.Root().PersistentFlags().GetString("toolsets") | ||||||||||||||||||||
| } | ||||||||||||||||||||
| if toolsetsFlag != "" && toolsetsFlag != "default" { | ||||||||||||||||||||
|
Comment on lines
+395
to
+400
|
||||||||||||||||||||
| // Ensure --toolsets is not also set | |
| toolsetsFlag, _ := cmd.Flags().GetString("toolsets") | |
| if toolsetsFlag == "" { | |
| toolsetsFlag, _ = cmd.Root().PersistentFlags().GetString("toolsets") | |
| } | |
| if toolsetsFlag != "" && toolsetsFlag != "default" { | |
| // Ensure --toolsets is not also explicitly set | |
| toolsetsSet := cmd.Flags().Changed("toolsets") || cmd.Root().PersistentFlags().Changed("toolsets") | |
| if toolsetsSet { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new Docker examples for
--toolsets/--toolsomit the required Vault connection env vars (e.g.,VAULT_ADDR,VAULT_TOKEN) and transport settings used in the preceding command, so as written they won’t successfully start a functional server. Consider showing these flags as additions to the existingdocker run ... -e VAULT_ADDR=... -e VAULT_TOKEN=...example (or explicitly noting that the same env/network flags are still required).