Skip to content

Commit 308fb63

Browse files
authored
[common][rpc/membership] Fx integration fixes (#6955)
* [common][rpc/membership] Fx integration fixes
1 parent 494314a commit 308fb63

File tree

6 files changed

+94
-26
lines changed

6 files changed

+94
-26
lines changed

common/membership/membershipfx/membershipfx.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/uber/cadence/common/log"
3232
"github.com/uber/cadence/common/membership"
3333
"github.com/uber/cadence/common/metrics"
34+
"github.com/uber/cadence/common/rpc"
3435
"github.com/uber/cadence/common/service"
3536
)
3637

@@ -41,6 +42,7 @@ type buildMembershipParams struct {
4142
fx.In
4243

4344
Clock clock.TimeSource
45+
RPCFactory rpc.Factory
4446
PeerProvider membership.PeerProvider
4547
Logger log.Logger
4648
MetricsClient metrics.Client
@@ -70,10 +72,25 @@ func buildMembership(params buildMembershipParams) (buildMembershipResult, error
7072
return buildMembershipResult{}, fmt.Errorf("create resolver: %w", err)
7173
}
7274

73-
params.Lifecycle.Append(fx.StartStopHook(resolver.Start, resolver.Stop))
75+
params.Lifecycle.Append(fx.StartStopHook(startResolver(resolver, params.RPCFactory), resolver.Stop))
7476

7577
return buildMembershipResult{
7678
Rings: rings,
7779
Resolver: resolver,
7880
}, nil
7981
}
82+
83+
func startResolver(resolver membership.Resolver, rpcFactory rpc.Factory) func() error {
84+
return func() error {
85+
err := rpcFactory.Start(resolver)
86+
if err != nil {
87+
return fmt.Errorf("start rpc factory: %w", err)
88+
}
89+
err = rpcFactory.GetDispatcher().Start()
90+
if err != nil {
91+
return fmt.Errorf("start rpc factory dispatcher: %w", err)
92+
}
93+
resolver.Start()
94+
return nil
95+
}
96+
}

common/membership/membershipfx/memebershipfx_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,42 @@ import (
2828
"go.uber.org/fx"
2929
"go.uber.org/fx/fxtest"
3030
"go.uber.org/mock/gomock"
31+
"go.uber.org/yarpc"
3132

3233
"github.com/uber/cadence/common/clock"
3334
"github.com/uber/cadence/common/log"
3435
"github.com/uber/cadence/common/log/testlogger"
3536
"github.com/uber/cadence/common/membership"
3637
"github.com/uber/cadence/common/metrics"
38+
"github.com/uber/cadence/common/rpc"
39+
"github.com/uber/cadence/common/service"
3740
)
3841

3942
func TestFxStartStop(t *testing.T) {
4043
app := fxtest.New(t, fx.Provide(func() appParams {
4144
ctrl := gomock.NewController(t)
4245
provider := membership.NewMockPeerProvider(ctrl)
46+
provider.EXPECT().Start()
47+
provider.EXPECT().Stop()
48+
for _, s := range service.ListWithRing {
49+
provider.EXPECT().Subscribe(s, gomock.Any()).Return(nil)
50+
provider.EXPECT().GetMembers(s).Return([]membership.HostInfo{}, nil)
51+
// this is also called by every ring, but could be called multiple times.
52+
provider.EXPECT().Stop()
53+
}
54+
factory := rpc.NewMockFactory(ctrl)
55+
factory.EXPECT().Start(gomock.Any()).Return(nil)
56+
factory.EXPECT().GetDispatcher().Return(yarpc.NewDispatcher(yarpc.Config{
57+
Name: "membership_test",
58+
}))
4359
return appParams{
4460
Clock: clock.NewMockedTimeSource(),
4561
PeerProvider: provider,
4662
Logger: testlogger.New(t),
4763
MetricsClient: metrics.NewNoopMetricsClient(),
64+
RPCFactory: factory,
4865
}
49-
}), Module)
66+
}), Module, fx.Invoke(func(resolver membership.Resolver) {}))
5067
app.RequireStart().RequireStop()
5168
}
5269

@@ -57,4 +74,5 @@ type appParams struct {
5774
PeerProvider membership.PeerProvider
5875
Logger log.Logger
5976
MetricsClient metrics.Client
77+
RPCFactory rpc.Factory
6078
}

common/peerprovider/ringpopprovider/ringpopfx/ringpopfx.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,5 @@ func buildRingpopProvider(params Params) (membership.PeerProvider, error) {
5454
if err != nil {
5555
return nil, err
5656
}
57-
params.Lifecycle.Append(fx.StartStopHook(provider.Start, provider.Stop))
5857
return provider, nil
5958
}

common/peerprovider/ringpopprovider/ringpopfx/ringpopfx_test.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,17 @@ package ringpopfx
2525
import (
2626
"testing"
2727

28+
"github.com/stretchr/testify/require"
29+
"github.com/uber/tchannel-go"
2830
"go.uber.org/fx"
2931
"go.uber.org/fx/fxtest"
3032
"go.uber.org/mock/gomock"
3133

3234
"github.com/uber/cadence/common/config"
3335
"github.com/uber/cadence/common/log"
3436
"github.com/uber/cadence/common/log/testlogger"
37+
"github.com/uber/cadence/common/membership"
38+
"github.com/uber/cadence/common/peerprovider/ringpopprovider"
3539
"github.com/uber/cadence/common/rpc"
3640
)
3741

@@ -41,23 +45,32 @@ func TestFxApp(t *testing.T) {
4145
func() testSetupParams {
4246
ctrl := gomock.NewController(t)
4347
factory := rpc.NewMockFactory(ctrl)
44-
factory.EXPECT().GetTChannel().Return(nil)
48+
tch, err := tchannel.NewChannel("test-ringpop", nil)
49+
require.NoError(t, err)
50+
factory.EXPECT().GetTChannel().Return(tch)
4551

4652
return testSetupParams{
4753
Service: "test",
4854
Logger: testlogger.New(t),
4955
RPCFactory: factory,
56+
Config: config.Config{
57+
Ringpop: ringpopprovider.Config{
58+
Name: "test-ringpop",
59+
BootstrapMode: ringpopprovider.BootstrapModeHosts,
60+
BootstrapHosts: []string{"127.0.0.1:7933", "127.0.0.1:7934", "127.0.0.1:7935"},
61+
},
62+
},
5063
}
5164
}),
52-
Module,
65+
Module, fx.Invoke(func(provider membership.PeerProvider) {}),
5366
)
5467
app.RequireStart().RequireStop()
5568
}
5669

5770
type testSetupParams struct {
5871
fx.Out
5972

60-
Service string `name:"service"`
73+
Service string `name:"service-full-name"`
6174
Config config.Config
6275
ServiceConfig config.Service
6376
Logger log.Logger

common/rpc/factory.go

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ var (
5252

5353
// Factory is an implementation of rpc.Factory interface
5454
type FactoryImpl struct {
55+
startOnce sync.Once
56+
stopOnce sync.Once
5557
maxMessageSize int
5658
channel tchannel.Channel
5759
dispatcher *yarpc.Dispatcher
@@ -182,33 +184,38 @@ func (d *FactoryImpl) GetMaxMessageSize() int {
182184
}
183185

184186
func (d *FactoryImpl) Start(peerLister PeerLister) error {
185-
d.peerLister = peerLister
186-
// subscribe to membership changes for history and matching. This is needed to update the peers for rpc
187-
for _, svc := range servicesToTalkP2P {
188-
ch := make(chan *membership.ChangedEvent, 1)
189-
if err := d.peerLister.Subscribe(svc, factoryComponentName, ch); err != nil {
190-
return fmt.Errorf("rpc factory failed to subscribe to membership updates for svc: %v, err: %v", svc, err)
187+
var err error
188+
d.startOnce.Do(func() {
189+
d.peerLister = peerLister
190+
// subscribe to membership changes for history and matching. This is needed to update the peers for rpc
191+
for _, svc := range servicesToTalkP2P {
192+
ch := make(chan *membership.ChangedEvent, 1)
193+
if err = d.peerLister.Subscribe(svc, factoryComponentName, ch); err != nil {
194+
err = fmt.Errorf("rpc factory failed to subscribe to membership updates for svc: %v, err: %w", svc, err)
195+
return
196+
}
197+
d.wg.Add(1)
198+
go d.listenMembershipChanges(svc, ch)
191199
}
192-
d.wg.Add(1)
193-
go d.listenMembershipChanges(svc, ch)
194-
}
195-
196-
return nil
200+
})
201+
return err
197202
}
198203

199204
func (d *FactoryImpl) Stop() error {
200-
d.logger.Info("stopping rpc factory")
205+
d.stopOnce.Do(func() {
206+
d.logger.Info("stopping rpc factory")
201207

202-
for _, svc := range servicesToTalkP2P {
203-
if err := d.peerLister.Unsubscribe(svc, factoryComponentName); err != nil {
204-
d.logger.Error("rpc factory failed to unsubscribe from membership updates", tag.Error(err), tag.Service(svc))
208+
for _, svc := range servicesToTalkP2P {
209+
if err := d.peerLister.Unsubscribe(svc, factoryComponentName); err != nil {
210+
d.logger.Error("rpc factory failed to unsubscribe from membership updates", tag.Error(err), tag.Service(svc))
211+
}
205212
}
206-
}
207213

208-
d.cancelFn()
209-
d.wg.Wait()
214+
d.cancelFn()
215+
d.wg.Wait()
210216

211-
d.logger.Info("stopped rpc factory")
217+
d.logger.Info("stopped rpc factory")
218+
})
212219
return nil
213220
}
214221

common/rpc/rpcfx/rpcfx.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ type factoryParams struct {
6969

7070
func buildFactory(p factoryParams) rpc.Factory {
7171
res := rpc.NewFactory(p.Logger, p.RPCParams)
72-
p.Lifecycle.Append(fx.StartStopHook(res.GetDispatcher().Start, res.GetDispatcher().Stop))
72+
p.Lifecycle.Append(fx.StopHook(rpcStopper(res)))
7373
return res
7474
}
75+
76+
func rpcStopper(factory rpc.Factory) func() error {
77+
return func() error {
78+
err := factory.GetDispatcher().Stop()
79+
if err != nil {
80+
return fmt.Errorf("dispatcher stop: %w", err)
81+
}
82+
err = factory.Stop()
83+
if err != nil {
84+
return fmt.Errorf("factory stop: %w", err)
85+
}
86+
return nil
87+
}
88+
}

0 commit comments

Comments
 (0)