Skip to content

Commit 7e29937

Browse files
committed
feat: implement MCP workspace roots integration - Add automatic extraction of client workspace roots during initialization - Provide convenient Context methods: GetRoots(), GetPrimaryRoot(), InRoots() - Follow Go idioms by extending existing server.roots system - Add comprehensive test coverage for workspace functionality - Update API documentation with workspace roots examples - Closes task #46
1 parent 0076f71 commit 7e29937

File tree

5 files changed

+727
-12
lines changed

5 files changed

+727
-12
lines changed

docs/api-reference/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,40 @@ server.Prompt("promptName", "description", template)
6767
server.Run()
6868
```
6969

70+
## Workspace Roots Integration
71+
72+
GOMCP servers automatically integrate with the MCP roots protocol to provide workspace context to tools. This eliminates the need for manual project_root parameters.
73+
74+
### Context API
75+
76+
Tools receive workspace information through the context:
77+
78+
```go
79+
func MyTool(ctx *server.Context, args struct{}) (interface{}, error) {
80+
// Get all workspace roots
81+
roots := ctx.GetRoots()
82+
83+
// Get primary workspace root
84+
primaryRoot := ctx.GetPrimaryRoot()
85+
86+
// Check if path is within workspace
87+
isInWorkspace := ctx.InRoots("/path/to/file")
88+
89+
return map[string]interface{}{
90+
"primary_root": primaryRoot,
91+
"all_roots": roots,
92+
"in_workspace": isInWorkspace,
93+
}, nil
94+
}
95+
```
96+
97+
### Features
98+
99+
- Automatic extraction of workspace roots from MCP client initialization
100+
- Thread-safe access to workspace context
101+
- Convenient helper methods for path validation
102+
- No manual configuration required
103+
70104
## Generating Documentation
71105

72106
API documentation is automatically generated from source code comments. For local documentation:

examples/cancellation/server/server.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func longRunningTool(ctx *server.Context, args struct {
2121
return "", fmt.Errorf("duration must be positive")
2222
}
2323

24-
fmt.Printf("Starting long-running task for %d seconds...\n", args.Duration)
24+
ctx.Logger.Info("Starting long-running task", "duration", args.Duration)
2525

2626
// Register for cancellation
2727
cancelCh := ctx.RegisterForCancellation()
@@ -32,25 +32,25 @@ func longRunningTool(ctx *server.Context, args struct {
3232

3333
// Method 1: Using the convenience method
3434
if err := ctx.CheckCancellation(); err != nil {
35-
fmt.Println("Task cancelled (using CheckCancellation)")
35+
ctx.Logger.Info("Task cancelled (using CheckCancellation)")
3636
return "", fmt.Errorf("task cancelled after %d seconds", i)
3737
}
3838

3939
// Method 2: Using the cancelCh directly
4040
select {
4141
case <-cancelCh:
42-
fmt.Println("Task cancelled (using cancel channel)")
42+
ctx.Logger.Info("Task cancelled (using cancel channel)")
4343
return "", fmt.Errorf("task cancelled after %d seconds", i)
4444
default:
4545
// Not cancelled, continue work
4646
}
4747

4848
// Do some "work"
49-
fmt.Printf("Working... %d/%d seconds completed\n", i+1, args.Duration)
49+
ctx.Logger.Info("Working...", "progress", fmt.Sprintf("%d/%d seconds completed", i+1, args.Duration))
5050
time.Sleep(1 * time.Second)
5151
}
5252

53-
fmt.Println("Task completed successfully!")
53+
ctx.Logger.Info("Task completed successfully!")
5454
return fmt.Sprintf("Completed task that took %d seconds", args.Duration), nil
5555
}
5656

@@ -63,7 +63,7 @@ func sendCancellation(srv server.Server, requestID string) {
6363
time.Sleep(2 * time.Second)
6464

6565
// Send the cancellation notification
66-
fmt.Println("Sending cancellation notification...")
66+
srv.Logger().Info("Sending cancellation notification...")
6767
err := s.SendCancelledNotification(requestID, "User requested cancellation")
6868
if err != nil {
6969
fmt.Printf("Error sending cancellation: %v\n", err)
@@ -109,14 +109,14 @@ func main() {
109109
// Print the response
110110
var response map[string]interface{}
111111
json.Unmarshal(responseBytes, &response)
112-
fmt.Println("\nResponse received:")
112+
impl.Logger().Info("Response received:")
113113
prettyJSON, _ := json.MarshalIndent(response, "", " ")
114-
fmt.Println(string(prettyJSON))
114+
impl.Logger().Info(string(prettyJSON))
115115

116116
// Also demonstrate cancellation in the real server
117-
fmt.Println("\nStarting real server example...")
117+
impl.Logger().Info("Starting real server example...")
118118
if err := srv.Run(); err != nil {
119-
fmt.Fprintf(os.Stderr, "Server error: %v\n", err)
119+
impl.Logger().Error("Server error", "error", err)
120120
os.Exit(1)
121121
}
122122
}

server/context.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,3 +767,28 @@ func (c *Context) ValidateToolArgs(toolName string) (interface{}, error) {
767767
// For now, return the raw arguments
768768
return c.Request.ToolArgs, nil
769769
}
770+
771+
// GetRoots returns all registered root paths from the server
772+
func (c *Context) GetRoots() []string {
773+
if c.server != nil {
774+
return c.server.GetRoots()
775+
}
776+
return []string{}
777+
}
778+
779+
// GetPrimaryRoot returns the first registered root path, if any
780+
func (c *Context) GetPrimaryRoot() string {
781+
roots := c.GetRoots()
782+
if len(roots) > 0 {
783+
return roots[0]
784+
}
785+
return ""
786+
}
787+
788+
// InRoots checks if a path is within any registered root
789+
func (c *Context) InRoots(path string) bool {
790+
if c.server != nil {
791+
return c.server.IsPathInRoots(path)
792+
}
793+
return false
794+
}

server/server.go

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"fmt"
99
"log/slog"
10+
"net/url"
1011
"os"
1112
"sync"
1213
"time"
@@ -699,12 +700,17 @@ func (s *serverImpl) ProcessInitialize(ctx *Context) (interface{}, error) {
699700
// Store the validated protocol version without locking
700701
s.protocolVersion = protocolVersion
701702

703+
// Extract and add workspace roots from client initialization
704+
workspaceRoots := extractWorkspaceRoots(ctx.Request.Params)
705+
if len(workspaceRoots) > 0 {
706+
s.roots = append(s.roots, workspaceRoots...)
707+
s.logger.Debug("added workspace roots from client", "roots", workspaceRoots)
708+
}
709+
702710
// Update the transport with the negotiated protocol version
703-
s.mu.RLock()
704711
if s.transport != nil {
705712
s.transport.SetProtocolVersion(protocolVersion)
706713
}
707-
s.mu.RUnlock()
708714

709715
// Determine sampling capabilities based on protocol version
710716
samplingCaps := DetectClientCapabilities(protocolVersion)
@@ -1017,6 +1023,81 @@ func (s *serverImpl) handleInitializedNotification() {
10171023
}()
10181024
}
10191025

1026+
// extractWorkspaceRoots extracts workspace root paths from initialization parameters
1027+
func extractWorkspaceRoots(params interface{}) []string {
1028+
if params == nil {
1029+
return nil
1030+
}
1031+
1032+
// Handle both parsed maps and JSON byte slices
1033+
var paramsMap map[string]interface{}
1034+
1035+
switch p := params.(type) {
1036+
case map[string]interface{}:
1037+
paramsMap = p
1038+
case json.RawMessage:
1039+
// Parse JSON bytes
1040+
if err := json.Unmarshal(p, &paramsMap); err != nil {
1041+
return nil
1042+
}
1043+
case []byte:
1044+
// Parse JSON bytes
1045+
if err := json.Unmarshal(p, &paramsMap); err != nil {
1046+
return nil
1047+
}
1048+
default:
1049+
return nil
1050+
}
1051+
1052+
// Look for clientInfo.roots according to MCP spec
1053+
clientInfo, ok := paramsMap["clientInfo"].(map[string]interface{})
1054+
if !ok {
1055+
return nil
1056+
}
1057+
1058+
roots, ok := clientInfo["roots"].([]interface{})
1059+
if !ok {
1060+
return nil
1061+
}
1062+
1063+
var result []string
1064+
for _, root := range roots {
1065+
if rootMap, ok := root.(map[string]interface{}); ok {
1066+
if uri, ok := rootMap["uri"].(string); ok {
1067+
// Convert file:// URIs to file paths
1068+
if path := uriToPath(uri); path != "" {
1069+
result = append(result, path)
1070+
}
1071+
}
1072+
}
1073+
}
1074+
1075+
return result
1076+
}
1077+
1078+
// uriToPath converts a file:// URI to a local file path
1079+
func uriToPath(uri string) string {
1080+
if uri == "" {
1081+
return ""
1082+
}
1083+
1084+
// Handle file:// URIs - must start with file:/// (three slashes for absolute paths)
1085+
if len(uri) > 8 && uri[:8] == "file:///" {
1086+
path := uri[7:] // Remove "file://" prefix, keeping the leading slash
1087+
1088+
// Handle URL decoding for special characters like %20 (space), %2B (+), etc.
1089+
if decoded, err := url.PathUnescape(path); err == nil {
1090+
return decoded
1091+
}
1092+
1093+
// If decoding fails, return the original path
1094+
return path
1095+
}
1096+
1097+
// Only file:/// URIs are supported (not file:// with server names)
1098+
return ""
1099+
}
1100+
10201101
// ListTools returns a list of all registered tools.
10211102
//
10221103
// This method provides programmatic access to the server's tool registry,

0 commit comments

Comments
 (0)