Skip to content

Commit 255e98f

Browse files
committed
Acquire snapshots synchronously while blocking other requests/notifications
1 parent 2cf2617 commit 255e98f

File tree

1 file changed

+102
-91
lines changed

1 file changed

+102
-91
lines changed

internal/lsp/server.go

Lines changed: 102 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -415,32 +415,34 @@ func (s *Server) dispatchLoop(ctx context.Context) error {
415415
s.pendingClientRequestsMu.Unlock()
416416
}
417417

418-
handle := func() {
419-
if err := s.handleRequestOrNotification(requestCtx, req); err != nil {
420-
if errors.Is(err, context.Canceled) {
421-
if err := s.sendError(req.ID, lsproto.ErrorCodeRequestCancelled); err != nil {
422-
lspExit(err)
423-
}
424-
} else if errors.Is(err, io.EOF) {
425-
lspExit(nil)
426-
} else {
427-
if err := s.sendError(req.ID, err); err != nil {
428-
lspExit(err)
429-
}
418+
handleError := func(err error) {
419+
if errors.Is(err, context.Canceled) {
420+
if err := s.sendError(req.ID, lsproto.ErrorCodeRequestCancelled); err != nil {
421+
lspExit(err)
422+
}
423+
} else if errors.Is(err, io.EOF) {
424+
lspExit(nil)
425+
} else {
426+
if err := s.sendError(req.ID, err); err != nil {
427+
lspExit(err)
430428
}
431429
}
430+
}
432431

433-
if req.ID != nil {
434-
s.pendingClientRequestsMu.Lock()
435-
delete(s.pendingClientRequests, *req.ID)
436-
s.pendingClientRequestsMu.Unlock()
437-
}
432+
if doAsyncWork, err := s.handleRequestOrNotification(requestCtx, req); err != nil {
433+
handleError(err)
434+
} else if doAsyncWork != nil {
435+
go func() {
436+
if lsError := doAsyncWork(); lsError != nil {
437+
handleError(lsError)
438+
}
439+
}()
438440
}
439441

440-
if isBlockingMethod(req.Method) {
441-
handle()
442-
} else {
443-
go handle()
442+
if req.ID != nil {
443+
s.pendingClientRequestsMu.Lock()
444+
delete(s.pendingClientRequests, *req.ID)
445+
s.pendingClientRequestsMu.Unlock()
444446
}
445447
}
446448
}
@@ -532,27 +534,46 @@ func (s *Server) send(msg *lsproto.Message) error {
532534
}
533535
}
534536

535-
func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.RequestMessage) error {
537+
// handleRequestOrNotification looks up the handler for the given request or notification, executes its synchronous work
538+
// and returns any asynchronous work as a function to be executed by the caller.
539+
func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.RequestMessage) (func() error, error) {
536540
ctx = lsproto.WithClientCapabilities(ctx, &s.clientCapabilities)
537541

538542
if handler := handlers()[req.Method]; handler != nil {
539543
start := time.Now()
540-
err := handler(s, ctx, req)
544+
doAsyncWork, err := handler(s, ctx, req)
541545
idStr := ""
542546
if req.ID != nil {
543547
idStr = " (" + req.ID.String() + ")"
544548
}
549+
if err != nil {
550+
s.logger.Error("error handling method '", req.Method, "'", idStr, ": ", err)
551+
return nil, err
552+
}
553+
if doAsyncWork != nil {
554+
return func() error {
555+
if ctx.Err() != nil {
556+
return ctx.Err()
557+
}
558+
asyncWorkErr := doAsyncWork()
559+
s.logger.Info(core.IfElse(asyncWorkErr != nil, "error handling method '", "handled method '"), req.Method, "'", idStr, " in ", time.Since(start))
560+
return asyncWorkErr
561+
}, nil
562+
}
545563
s.logger.Info("handled method '", req.Method, "'", idStr, " in ", time.Since(start))
546-
return err
564+
return nil, nil
547565
}
548566
s.logger.Warn("unknown method '", req.Method, "'")
549567
if req.ID != nil {
550-
return s.sendError(req.ID, lsproto.ErrorCodeInvalidRequest)
568+
return nil, s.sendError(req.ID, lsproto.ErrorCodeInvalidRequest)
551569
}
552-
return nil
570+
return nil, nil
553571
}
554572

555-
type handlerMap map[lsproto.Method]func(*Server, context.Context, *lsproto.RequestMessage) error
573+
// handlerMap maps LSP method to a handler function. The handler function executes any work that must be done synchronously
574+
// before other requests/notifications can be processed, and returns any additional work as a function to be executed
575+
// asynchronously after the synchronous work is complete.
576+
type handlerMap map[lsproto.Method]func(*Server, context.Context, *lsproto.RequestMessage) (func() error, error)
556577

557578
var handlers = sync.OnceValue(func() handlerMap {
558579
handlers := make(handlerMap)
@@ -615,9 +636,9 @@ var handlers = sync.OnceValue(func() handlerMap {
615636
})
616637

617638
func registerNotificationHandler[Req any](handlers handlerMap, info lsproto.NotificationInfo[Req], fn func(*Server, context.Context, Req) error) {
618-
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error {
639+
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) {
619640
if s.session == nil && req.Method != lsproto.MethodInitialized {
620-
return lsproto.ErrorCodeServerNotInitialized
641+
return nil, lsproto.ErrorCodeServerNotInitialized
621642
}
622643

623644
var params Req
@@ -626,9 +647,9 @@ func registerNotificationHandler[Req any](handlers handlerMap, info lsproto.Noti
626647
params = req.Params.(Req)
627648
}
628649
if err := fn(s, ctx, params); err != nil {
629-
return err
650+
return nil, err
630651
}
631-
return ctx.Err()
652+
return nil, ctx.Err()
632653
}
633654
}
634655

@@ -637,9 +658,9 @@ func registerRequestHandler[Req, Resp any](
637658
info lsproto.RequestInfo[Req, Resp],
638659
fn func(*Server, context.Context, Req, *lsproto.RequestMessage) (Resp, error),
639660
) {
640-
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error {
661+
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) {
641662
if s.session == nil && req.Method != lsproto.MethodInitialize {
642-
return lsproto.ErrorCodeServerNotInitialized
663+
return nil, lsproto.ErrorCodeServerNotInitialized
643664
}
644665

645666
var params Req
@@ -649,71 +670,75 @@ func registerRequestHandler[Req, Resp any](
649670
}
650671
resp, err := fn(s, ctx, params, req)
651672
if err != nil {
652-
return err
673+
return nil, err
653674
}
654675
if ctx.Err() != nil {
655-
return ctx.Err()
676+
return nil, ctx.Err()
656677
}
657-
return s.sendResult(req.ID, resp)
678+
return nil, s.sendResult(req.ID, resp)
658679
}
659680
}
660681

661682
func registerLanguageServiceDocumentRequestHandler[Req lsproto.HasTextDocumentURI, Resp any](handlers handlerMap, info lsproto.RequestInfo[Req, Resp], fn func(*Server, context.Context, *ls.LanguageService, Req) (Resp, error)) {
662-
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error {
683+
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) {
663684
var params Req
664685
// Ignore empty params.
665686
if req.Params != nil {
666687
params = req.Params.(Req)
667688
}
668689
ls, err := s.session.GetLanguageService(ctx, params.TextDocumentURI())
669690
if err != nil {
670-
return err
691+
return nil, err
671692
}
672-
defer s.recover(ctx, req)
673-
resp, err := fn(s, ctx, ls, params)
674-
if err != nil {
675-
return err
676-
}
677-
if ctx.Err() != nil {
678-
return ctx.Err()
679-
}
680-
return s.sendResult(req.ID, resp)
693+
return func() error {
694+
defer s.recover(ctx, req)
695+
resp, lsErr := fn(s, ctx, ls, params)
696+
if lsErr != nil {
697+
return lsErr
698+
}
699+
if ctx.Err() != nil {
700+
return ctx.Err()
701+
}
702+
return s.sendResult(req.ID, resp)
703+
}, nil
681704
}
682705
}
683706

684707
func registerLanguageServiceWithAutoImportsRequestHandler[Req lsproto.HasTextDocumentURI, Resp any](handlers handlerMap, info lsproto.RequestInfo[Req, Resp], fn func(*Server, context.Context, *ls.LanguageService, Req) (Resp, error)) {
685-
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error {
708+
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) {
686709
var params Req
687710
// Ignore empty params.
688711
if req.Params != nil {
689712
params = req.Params.(Req)
690713
}
691714
languageService, err := s.session.GetLanguageService(ctx, params.TextDocumentURI())
692715
if err != nil {
693-
return err
716+
return nil, err
694717
}
695-
defer s.recover(ctx, req)
696-
resp, err := fn(s, ctx, languageService, params)
697-
if errors.Is(err, ls.ErrNeedsAutoImports) {
698-
languageService, err = s.session.GetLanguageServiceWithAutoImports(ctx, params.TextDocumentURI())
699-
if err != nil {
700-
return err
718+
return func() error {
719+
defer s.recover(ctx, req)
720+
resp, lsErr := fn(s, ctx, languageService, params)
721+
if errors.Is(lsErr, ls.ErrNeedsAutoImports) {
722+
languageService, lsErr = s.session.GetLanguageServiceWithAutoImports(ctx, params.TextDocumentURI())
723+
if lsErr != nil {
724+
return lsErr
725+
}
726+
if ctx.Err() != nil {
727+
return ctx.Err()
728+
}
729+
resp, lsErr = fn(s, ctx, languageService, params)
730+
if errors.Is(lsErr, ls.ErrNeedsAutoImports) {
731+
panic(info.Method + " returned ErrNeedsAutoImports even after enabling auto imports")
732+
}
733+
}
734+
if lsErr != nil {
735+
return lsErr
701736
}
702737
if ctx.Err() != nil {
703738
return ctx.Err()
704739
}
705-
resp, err = fn(s, ctx, languageService, params)
706-
if errors.Is(err, ls.ErrNeedsAutoImports) {
707-
panic(info.Method + " returned ErrNeedsAutoImports even after enabling auto imports")
708-
}
709-
}
710-
if err != nil {
711-
return err
712-
}
713-
if ctx.Err() != nil {
714-
return ctx.Err()
715-
}
716-
return s.sendResult(req.ID, resp)
740+
return s.sendResult(req.ID, resp)
741+
}, nil
717742
}
718743
}
719744

@@ -722,7 +747,7 @@ func registerMultiProjectReferenceRequestHandler[Req lsproto.HasTextDocumentPosi
722747
info lsproto.RequestInfo[Req, Resp],
723748
fn func(*ls.LanguageService, context.Context, Req, ls.CrossProjectOrchestrator) (Resp, error),
724749
) {
725-
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error {
750+
handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) {
726751
var params Req
727752
// Ignore empty params.
728753
if req.Params != nil {
@@ -731,14 +756,16 @@ func registerMultiProjectReferenceRequestHandler[Req lsproto.HasTextDocumentPosi
731756
// !!! sheetal: multiple projects that contain the file through symlinks
732757
defaultLs, orchestrator, err := s.getLanguageServiceAndCrossProjectOrchestrator(ctx, params.TextDocumentURI(), req)
733758
if err != nil {
734-
return err
735-
}
736-
defer s.recover(ctx, req)
737-
resp, err := fn(defaultLs, ctx, params, orchestrator)
738-
if err != nil {
739-
return err
759+
return nil, err
740760
}
741-
return s.sendResult(req.ID, resp)
761+
return func() error {
762+
defer s.recover(ctx, req)
763+
resp, lsErr := fn(defaultLs, ctx, params, orchestrator)
764+
if lsErr != nil {
765+
return lsErr
766+
}
767+
return s.sendResult(req.ID, resp)
768+
}, nil
742769
}
743770
}
744771

@@ -1343,22 +1370,6 @@ func (s *Server) NpmInstall(cwd string, args []string) ([]byte, error) {
13431370
return s.npmInstall(cwd, args)
13441371
}
13451372

1346-
func isBlockingMethod(method lsproto.Method) bool {
1347-
switch method {
1348-
case lsproto.MethodInitialize,
1349-
lsproto.MethodInitialized,
1350-
lsproto.MethodTextDocumentDidOpen,
1351-
lsproto.MethodTextDocumentDidChange,
1352-
lsproto.MethodTextDocumentDidSave,
1353-
lsproto.MethodTextDocumentDidClose,
1354-
lsproto.MethodWorkspaceDidChangeWatchedFiles,
1355-
lsproto.MethodWorkspaceDidChangeConfiguration,
1356-
lsproto.MethodWorkspaceConfiguration:
1357-
return true
1358-
}
1359-
return false
1360-
}
1361-
13621373
func ptrTo[T any](v T) *T {
13631374
return &v
13641375
}

0 commit comments

Comments
 (0)