Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ func TestEndToEnd(t *testing.T) {
var ct, st Transport = NewInMemoryTransports()

// Channels to check if notification callbacks happened.
// These test asynchronous sending of notifications after a small delay (see
// Server.sendNotification).
notificationChans := map[string]chan int{}
for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe", "elicitation_complete"} {
notificationChans[name] = make(chan int, 1)
Expand Down Expand Up @@ -1695,14 +1697,15 @@ func TestSynchronousNotifications(t *testing.T) {
},
}
server := NewServer(testImpl, serverOpts)
cs, ss, cleanup := basicClientServerConnection(t, client, server, func(s *Server) {
addTool := func(s *Server) {
AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
if !rootsChanged.Load() {
return nil, nil, fmt.Errorf("didn't get root change notification")
}
return new(CallToolResult), nil, nil
})
})
}
cs, ss, cleanup := basicClientServerConnection(t, client, server, addTool)
defer cleanup()

t.Run("from client", func(t *testing.T) {
Expand All @@ -1717,7 +1720,11 @@ func TestSynchronousNotifications(t *testing.T) {
})

t.Run("from server", func(t *testing.T) {
server.RemoveTools("tool")
// Because server change notifications are batched, we must generate a lot of them.
for range maxPendingNotifications/2 + 1 {
server.RemoveTools("tool")
addTool(server)
}
if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil {
t.Errorf("CreateMessage failed: %v", err)
}
Expand Down
69 changes: 52 additions & 17 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type Server struct {
sendingMethodHandler_ MethodHandler
receivingMethodHandler_ MethodHandler
resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool
pendingNotifications map[string]int // notification name -> count of unsent changes
}

// ServerOptions is used to configure behavior of the server.
Expand Down Expand Up @@ -149,6 +150,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server {
sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession],
receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession],
resourceSubscriptions: make(map[string]map[*ServerSession]bool),
pendingNotifications: make(map[string]int),
}
}

Expand All @@ -158,15 +160,13 @@ func (s *Server) AddPrompt(p *Prompt, h PromptHandler) {
// (It's possible an item was replaced with an identical one, but not worth checking.)
s.changeAndNotify(
notificationPromptListChanged,
&PromptListChangedParams{},
func() bool { s.prompts.add(&serverPrompt{p, h}); return true })
}

// RemovePrompts removes the prompts with the given names.
// It is not an error to remove a nonexistent prompt.
func (s *Server) RemovePrompts(names ...string) {
s.changeAndNotify(notificationPromptListChanged, &PromptListChangedParams{},
func() bool { return s.prompts.remove(names...) })
s.changeAndNotify(notificationPromptListChanged, func() bool { return s.prompts.remove(names...) })
}

// AddTool adds a [Tool] to the server, or replaces one with the same name.
Expand Down Expand Up @@ -235,8 +235,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
// (It's possible a tool was replaced with an identical one, but not worth checking.)
// TODO: Batch these changes by size and time? The typescript SDK doesn't.
// TODO: Surface notify error here? best not, in case we need to batch.
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
func() bool { s.tools.add(st); return true })
s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true })
}

func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) {
Expand Down Expand Up @@ -419,14 +418,13 @@ func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
// RemoveTools removes the tools with the given names.
// It is not an error to remove a nonexistent tool.
func (s *Server) RemoveTools(names ...string) {
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
func() bool { return s.tools.remove(names...) })
s.changeAndNotify(notificationToolListChanged, func() bool { return s.tools.remove(names...) })
}

// AddResource adds a [Resource] to the server, or replaces one with the same URI.
// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme).
func (s *Server) AddResource(r *Resource, h ResourceHandler) {
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
s.changeAndNotify(notificationResourceListChanged,
func() bool {
if _, err := url.Parse(r.URI); err != nil {
panic(err) // url.Parse includes the URI in the error
Expand All @@ -439,14 +437,13 @@ func (s *Server) AddResource(r *Resource, h ResourceHandler) {
// RemoveResources removes the resources with the given URIs.
// It is not an error to remove a nonexistent resource.
func (s *Server) RemoveResources(uris ...string) {
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
func() bool { return s.resources.remove(uris...) })
s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resources.remove(uris...) })
}

// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI.
// AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme).
func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) {
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
s.changeAndNotify(notificationResourceListChanged,
func() bool {
// Validate the URI template syntax
_, err := uritemplate.New(t.URITemplate)
Expand All @@ -461,8 +458,7 @@ func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) {
// RemoveResourceTemplates removes the resource templates with the given URI templates.
// It is not an error to remove a nonexistent resource.
func (s *Server) RemoveResourceTemplates(uriTemplates ...string) {
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
func() bool { return s.resourceTemplates.remove(uriTemplates...) })
s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resourceTemplates.remove(uriTemplates...) })
}

func (s *Server) capabilities() *ServerCapabilities {
Expand Down Expand Up @@ -497,18 +493,57 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR
return s.opts.CompletionHandler(ctx, req)
}

// Map from notification name to its corresponding params. The params have no fields,
// so a single struct can be reused.
var changeNotificationParams = map[string]Params{
notificationToolListChanged: &ToolListChangedParams{},
notificationPromptListChanged: &PromptListChangedParams{},
notificationResourceListChanged: &ResourceListChangedParams{},
}

// The maximum number of change notifications of a particular type (e.g. tools-changed)
// that can be pending.
const maxPendingNotifications = 10

// How long to wait before sending a change notification.
var notificationDelay = 50 * time.Millisecond

// changeAndNotify is called when a feature is added or removed.
// It calls change, which should do the work and report whether a change actually occurred.
// If there was a change, it notifies a snapshot of the sessions.
func (s *Server) changeAndNotify(notification string, params Params, change func() bool) {
func (s *Server) changeAndNotify(notification string, change func() bool) {
var sessions []*ServerSession
// Lock for the change, but not for the notification.
send := false
s.mu.Lock()
if change() {
sessions = slices.Clone(s.sessions)
pending := s.pendingNotifications[notification]
if pending >= maxPendingNotifications {
send = true
pending = 0
// Make a local copy of the session list so we can use it without holding the lock.
sessions = slices.Clone(s.sessions)
} else {
pending++
if pending == 1 {
time.AfterFunc(notificationDelay, func() { s.sendNotification(notification) })
}
}
s.pendingNotifications[notification] = pending
}
s.mu.Unlock() // Don't hold lock during notifications.
if send {
notifySessions(sessions, notification, changeNotificationParams[notification])
}
}

// sendNotification is called asynchronously to ensure that notifications are sent
// soon after they occur.
func (s *Server) sendNotification(n string) {
s.mu.Lock()
sessions := slices.Clone(s.sessions)
s.pendingNotifications[n] = 0
s.mu.Unlock()
notifySessions(sessions, notification, params)
notifySessions(sessions, n, changeNotificationParams[n])
}

// Sessions returns an iterator that yields the current set of server sessions.
Expand Down