diff --git a/client/option_test.go b/client/option_test.go index 7fd8a7bf9e..c1ea33c2e1 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -44,6 +44,9 @@ import ( "github.com/cloudwego/kitex/pkg/http" "github.com/cloudwego/kitex/pkg/loadbalance" "github.com/cloudwego/kitex/pkg/proxy" + connpool2 "github.com/cloudwego/kitex/pkg/remote/connpool" + "github.com/cloudwego/kitex/pkg/remote/trans/gonet" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" "github.com/cloudwego/kitex/pkg/retry" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -764,6 +767,20 @@ func TestTailOption(t *testing.T) { test.Assert(t, opts.RemoteOpt.Dialer != nil) } +func TestGonetOption(t *testing.T) { + // gonet + opt := client.NewOptions([]Option{WithDialer(gonet.NewDialer()), WithLongConnection(connpool.IdleConfig{MaxIdlePerAddress: 10})}) + d := opt.RemoteOpt.ConnPool.(*connpool2.LongPool) + pcfg := d.Config() + test.Assert(t, pcfg.Enable) + + // netpoll + opt = client.NewOptions([]Option{WithDialer(netpoll.NewDialer()), WithLongConnection(connpool.IdleConfig{MaxIdlePerAddress: 10})}) + d = opt.RemoteOpt.ConnPool.(*connpool2.LongPool) + pcfg = d.Config() + test.Assert(t, !pcfg.Enable) +} + func checkOneOptionDebugInfo(t *testing.T, opt Option, expectStr string) error { o := &Options{} o.Apply([]Option{opt}) diff --git a/internal/client/option.go b/internal/client/option.go index 956e9a2bfa..00483a680f 100644 --- a/internal/client/option.go +++ b/internal/client/option.go @@ -24,6 +24,7 @@ import ( "github.com/cloudwego/localsession/backup" "github.com/cloudwego/kitex/internal/configutil" + internalRemote "github.com/cloudwego/kitex/internal/remote" "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/pkg/acl" "github.com/cloudwego/kitex/pkg/circuitbreak" @@ -292,22 +293,27 @@ func (o *Options) initRemoteOpt() { } o.RemoteOpt.TTHeaderStreamingProvider = ttstream.NewClientProvider(o.TTHeaderStreamingOptions.TransportOptions...) } + + _, setConnPoolProactiveCheck := o.RemoteOpt.Dialer.(internalRemote.IsGonetDialer) if o.RemoteOpt.ConnPool == nil { if o.PoolCfg != nil { if *o.PoolCfg == zero { o.RemoteOpt.ConnPool = connpool.NewShortPool(o.Svr.ServiceName) } else { - o.RemoteOpt.ConnPool = connpool.NewLongPool(o.Svr.ServiceName, *o.PoolCfg) + cfg := newDefaultLongPoolCfg(o.Svr.ServiceName, *o.PoolCfg, setConnPoolProactiveCheck) + o.RemoteOpt.ConnPool = connpool.NewLongPoolWithConfig(cfg) } } else { - o.RemoteOpt.ConnPool = connpool.NewLongPool( + cfg := newDefaultLongPoolCfg( o.Svr.ServiceName, connpool2.IdleConfig{ MaxIdlePerAddress: 10, MaxIdleGlobal: 100, MaxIdleTimeout: time.Minute, }, + setConnPoolProactiveCheck, ) + o.RemoteOpt.ConnPool = connpool.NewLongPoolWithConfig(cfg) } } } @@ -319,3 +325,15 @@ func (o *Options) InitRetryContainer() { o.CloseCallbacks = append(o.CloseCallbacks, o.UnaryOptions.RetryContainer.Close) } } + +func newDefaultLongPoolCfg(serviceName string, idleCfg connpool2.IdleConfig, enableProactiveCheck bool) connpool.LongPoolConfig { + return connpool.LongPoolConfig{ + ServiceName: serviceName, + IdleConfig: idleCfg, + ProactiveCheckConfig: connpool.ProactiveCheckConfig{ + Enable: enableProactiveCheck, + CheckFunc: internalRemote.ConnectionStateCheck, + Interval: connpool.DefaultProactiveConnCheckInterval, + }, + } +} diff --git a/internal/remote/conn_check.go b/internal/remote/conn_check.go new file mode 100644 index 0000000000..62e3737006 --- /dev/null +++ b/internal/remote/conn_check.go @@ -0,0 +1,70 @@ +//go:build !windows + +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package remote + +import ( + "fmt" + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +// ConnectionStateCheck uses unix.Poll to detect the connection state +// Since the connections are stored in the pool, we treat any POLLIN event as connection close and set the connection state to closed. +func ConnectionStateCheck(conns ...net.Conn) error { + pollFds := make([]unix.PollFd, 0, len(conns)) + + for _, conn := range conns { + sysConn, ok := conn.(syscall.Conn) + if !ok { + return fmt.Errorf("conn is not a syscall.Conn, got %T", conn) + } + rawConn, err := sysConn.SyscallConn() + if err != nil { + return err + } + var fd int + err = rawConn.Control(func(fileDescriptor uintptr) { + fd = int(fileDescriptor) + }) + if err != nil { + return err + } + pollFds = append(pollFds, unix.PollFd{Fd: int32(fd), Events: unix.POLLIN}) + } + + n, err := unix.Poll(pollFds, 0) + if err != nil { + return err + } + if n == 0 { + return nil + } + for i := 0; i < len(pollFds); i++ { + if pollFds[i].Revents&unix.POLLIN != 0 { + // the connection should not receive any data, POLLIN means FIN or RST + // set the state + if s, ok := conns[i].(SetConnState); ok { + s.SetConnState(true) + } + } + } + return nil +} diff --git a/internal/remote/conn_check_test.go b/internal/remote/conn_check_test.go new file mode 100644 index 0000000000..2d212ebe4c --- /dev/null +++ b/internal/remote/conn_check_test.go @@ -0,0 +1,86 @@ +//go:build !windows + +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package remote + +import ( + "errors" + "net" + "strings" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" +) + +var _ SetConnState = &mockConn{} + +type mockConn struct { + net.Conn + closed atomic.Bool +} + +func (m *mockConn) SetConnState(c bool) { + m.closed.Store(c) +} + +func (m *mockConn) SyscallConn() (syscall.RawConn, error) { + if sc, ok := m.Conn.(syscall.Conn); ok { + return sc.SyscallConn() + } + return nil, errors.New("not syscall.Conn") +} + +func TestConnectionStateCheck(t *testing.T) { + // wrong connection type + err := ConnectionStateCheck(net.Pipe()) + test.Assert(t, err != nil) + test.Assert(t, strings.Contains(err.Error(), "conn is not a syscall.Conn")) + + ln, err := net.Listen("tcp", "127.0.0.1:0") // 本地端口自动分配 + test.Assert(t, err == nil, err) + defer ln.Close() + + done := make(chan net.Conn) + go func() { + conn, e := ln.Accept() + test.Assert(t, e == nil) + done <- conn + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + test.Assert(t, err == nil, err) + + serverConn := <-done + serverConnWithState := &mockConn{Conn: serverConn} + // check, not closed + err = ConnectionStateCheck(serverConnWithState) + test.Assert(t, err == nil, err) + test.Assert(t, !serverConnWithState.closed.Load()) + + // close conn + clientConn.Close() + time.Sleep(100 * time.Millisecond) + + // check, closed + err = ConnectionStateCheck(serverConnWithState) + test.Assert(t, err == nil, err) + test.Assert(t, serverConnWithState.closed.Load()) +} diff --git a/internal/remote/conn_check_windows.go b/internal/remote/conn_check_windows.go new file mode 100644 index 0000000000..37d4722fc9 --- /dev/null +++ b/internal/remote/conn_check_windows.go @@ -0,0 +1,29 @@ +//go:build windows +// +build windows + +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package remote + +import ( + "net" +) + +// FIXME: windows not supported +func ConnectionStateCheck(conns ...net.Conn) error { + return nil +} diff --git a/internal/remote/gonet.go b/internal/remote/gonet.go new file mode 100644 index 0000000000..5de2156c85 --- /dev/null +++ b/internal/remote/gonet.go @@ -0,0 +1,27 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package remote + +// SetConnState only used to set the state to connection in gonet +type SetConnState interface { + SetConnState(inactive bool) +} + +// IsGonetDialer returns if the dialer is gonet dialer +type IsGonetDialer interface { + IsGonetDialer() bool +} diff --git a/pkg/remote/connpool/long_pool.go b/pkg/remote/connpool/long_pool.go index 50c0a78c53..56c8d5996d 100644 --- a/pkg/remote/connpool/long_pool.go +++ b/pkg/remote/connpool/long_pool.go @@ -26,6 +26,7 @@ import ( "time" "github.com/cloudwego/kitex/pkg/connpool" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/kitex/pkg/warmup" @@ -40,7 +41,8 @@ var ( ) const ( - configDumpKey = "idle_config" + DefaultProactiveConnCheckInterval = 3 * time.Second + idleConfigDumpKey = "idle_config" ) func getSharedTicker(p *LongPool, refreshInterval time.Duration) *utils.SharedTicker { @@ -73,6 +75,7 @@ type longConn struct { net.Conn sync.RWMutex deadline time.Time + pooledAt time.Time address string } @@ -105,12 +108,18 @@ type PoolDump struct { ConnsDeadline []time.Time `json:"conns_deadline"` } -func newPool(minIdle, maxIdle int, maxIdleTimeout time.Duration) *pool { +func newPool(config connpool.IdleConfig, proactiveCheckCfg ProactiveCheckConfig) *pool { p := &pool{ - idleList: make([]*longConn, 0, maxIdle), - minIdle: minIdle, - maxIdle: maxIdle, - maxIdleTimeout: maxIdleTimeout, + idleList: make([]*longConn, 0, config.MaxIdlePerAddress), + minIdle: config.MinIdlePerAddress, + maxIdle: config.MaxIdlePerAddress, + maxIdleTimeout: config.MaxIdleTimeout, + ProactiveCheckConfig: proactiveCheckCfg, + } + if p.ProactiveCheckConfig.Enable { + if p.maxIdleTimeout < p.ProactiveCheckConfig.Interval { + p.ProactiveCheckConfig.Interval = p.maxIdleTimeout + } } return p } @@ -123,6 +132,7 @@ type pool struct { minIdle int maxIdle int // currIdle <= maxIdle. maxIdleTimeout time.Duration // the idle connection will be cleaned if the idle time exceeds maxIdleTimeout. + ProactiveCheckConfig } // Get gets the first active connection from the idleList. Return the number of connections decreased during the Get. @@ -157,7 +167,9 @@ func (p *pool) Put(o *longConn) bool { p.mu.Lock() var recycled bool if len(p.idleList) < p.maxIdle { - o.deadline = time.Now().Add(p.maxIdleTimeout) + now := time.Now() + o.deadline = now.Add(p.maxIdleTimeout) + o.pooledAt = now p.idleList = append(p.idleList, o) recycled = true } @@ -169,6 +181,15 @@ func (p *pool) Put(o *longConn) bool { // Evict returns how many connections has been evicted. func (p *pool) Evict() (evicted int) { p.mu.Lock() + defer p.mu.Unlock() + + if p.ProactiveCheckConfig.Enable { + // connection state check, this will set the state to the closed connections + if err := p.checkConnState(); err != nil { + klog.Errorf("KITEX: connpool health check failed: %v", err) + } + } + nonIdle := len(p.idleList) - p.minIdle // clear non idle connections for ; evicted < nonIdle; evicted++ { @@ -182,10 +203,26 @@ func (p *pool) Evict() (evicted int) { p.idleList[evicted] = nil } p.idleList = p.idleList[evicted:] - p.mu.Unlock() return evicted } +// checkConnState checks and sets the state of connections that have been idle for more than connCheckInterval. +func (p *pool) checkConnState() error { + if connCheckFunc := p.ProactiveCheckConfig.CheckFunc; connCheckFunc != nil { + var toCheck []net.Conn + now := time.Now() + for _, conn := range p.idleList { + if !now.After(conn.pooledAt.Add(p.ProactiveCheckConfig.Interval)) { + // only check if now >= pooledAt+interval + break + } + toCheck = append(toCheck, conn.RawConn()) + } + return connCheckFunc(toCheck...) + } + return nil +} + // Len returns the length of the pool. func (p *pool) Len() int { p.mu.RLock() @@ -226,16 +263,15 @@ func (p *pool) Dump() PoolDump { func newPeer( serviceName string, addr net.Addr, - minIdle int, - maxIdle int, - maxIdleTimeout time.Duration, + idleCfg connpool.IdleConfig, + proactiveCheckCfg ProactiveCheckConfig, globalIdle *utils.MaxCounter, ) *peer { return &peer{ serviceName: serviceName, addr: addr, globalIdle: globalIdle, - pool: newPool(minIdle, maxIdle, maxIdleTimeout), + pool: newPool(idleCfg, proactiveCheckCfg), } } @@ -298,9 +334,54 @@ func (p *peer) Close() { p.globalIdle.DecN(int64(n)) } +type LongPoolConfig struct { + ServiceName string + connpool.IdleConfig + ProactiveCheckConfig +} + +// ProactiveCheckConfig is the config of proactive connection detection logic. +// only for go net now to avoid using closed connections. +type ProactiveCheckConfig struct { + Enable bool // if true, the connection pool will check the aliveness of conn. + Interval time.Duration + CheckFunc func(conn ...net.Conn) error // CheckFunc is used to detect the connection state +} + +// NewLongPoolWithConfig creates a long pool using the given LongPoolConfig. +func NewLongPoolWithConfig(cfg LongPoolConfig) *LongPool { + idleConfig := cfg.IdleConfig + limit := utils.NewMaxCounter(idleConfig.MaxIdleGlobal) + lp := &LongPool{ + reporter: &DummyReporter{}, + globalIdle: limit, + newPeer: func(addr net.Addr) *peer { + return newPeer( + cfg.ServiceName, + addr, + idleConfig, + cfg.ProactiveCheckConfig, + limit) + }, + config: cfg, + } + + evictInterval := idleConfig.MaxIdleTimeout + if cfg.ProactiveCheckConfig.Enable { + if interval := cfg.ProactiveCheckConfig.Interval; interval < evictInterval { + evictInterval = interval + } + } + + // add this long pool into the sharedTicker + lp.sharedTicker = getSharedTicker(lp, evictInterval) + return lp +} + // NewLongPool creates a long pool using the given IdleConfig. func NewLongPool(serviceName string, idlConfig connpool.IdleConfig) *LongPool { limit := utils.NewMaxCounter(idlConfig.MaxIdleGlobal) + pcfg := ProactiveCheckConfig{} lp := &LongPool{ reporter: &DummyReporter{}, globalIdle: limit, @@ -308,12 +389,15 @@ func NewLongPool(serviceName string, idlConfig connpool.IdleConfig) *LongPool { return newPeer( serviceName, addr, - idlConfig.MinIdlePerAddress, - idlConfig.MaxIdlePerAddress, - idlConfig.MaxIdleTimeout, + idlConfig, + pcfg, limit) }, - idleConfig: idlConfig, + config: LongPoolConfig{ + serviceName, + idlConfig, + pcfg, + }, } // add this long pool into the sharedTicker lp.sharedTicker = getSharedTicker(lp, idlConfig.MaxIdleTimeout) @@ -326,7 +410,7 @@ type LongPool struct { peerMap sync.Map newPeer func(net.Addr) *peer globalIdle *utils.MaxCounter - idleConfig connpool.IdleConfig + config LongPoolConfig sharedTicker *utils.SharedTicker closed int32 // active: 0, closed: 1 } @@ -377,7 +461,7 @@ func (lp *LongPool) Clean(network, address string) { // Dump is used to dump current long pool info when needed, like debug query. func (lp *LongPool) Dump() interface{} { m := make(map[string]interface{}) - m[configDumpKey] = lp.idleConfig + m[idleConfigDumpKey] = lp.config.IdleConfig lp.peerMap.Range(func(key, value interface{}) bool { t := value.(*peer).pool.Dump() m[key.(netAddr).String()] = t @@ -386,6 +470,11 @@ func (lp *LongPool) Dump() interface{} { return m } +// Config returns the config of the long pool +func (lp *LongPool) Config() LongPoolConfig { + return lp.config +} + // Close releases all peers in the pool, it is executed when client is closed. func (lp *LongPool) Close() error { if !atomic.CompareAndSwapInt32(&lp.closed, 0, 1) { diff --git a/pkg/remote/connpool/long_pool_test.go b/pkg/remote/connpool/long_pool_test.go index 34472aa725..1003361d39 100644 --- a/pkg/remote/connpool/long_pool_test.go +++ b/pkg/remote/connpool/long_pool_test.go @@ -22,19 +22,20 @@ import ( "fmt" "math/rand" "net" + "reflect" "runtime" "sync" "sync/atomic" "testing" "time" - "github.com/cloudwego/kitex/pkg/connpool" - "github.com/golang/mock/gomock" mocksnetpoll "github.com/cloudwego/kitex/internal/mocks/netpoll" mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" + internalRemote "github.com/cloudwego/kitex/internal/remote" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/connpool" dialer "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/utils" ) @@ -55,7 +56,7 @@ func TestPoolReuse(t *testing.T) { maxIdleTimeout = time.Millisecond ) - p := newPool(minIdle, maxIdle, maxIdleTimeout) + p := newPool(connpool.IdleConfig{MinIdlePerAddress: minIdle, MaxIdlePerAddress: maxIdle, MaxIdleTimeout: maxIdleTimeout}, ProactiveCheckConfig{}) count := make(map[*longConn]bool) conn := newLongConnForTest(ctrl, mockAddr0) @@ -82,7 +83,7 @@ func TestPoolGetInactiveConn(t *testing.T) { maxIdleTimeout = time.Millisecond ) - p := newPool(minIdle, maxIdle, maxIdleTimeout) + p := newPool(connpool.IdleConfig{MinIdlePerAddress: minIdle, MaxIdlePerAddress: maxIdle, MaxIdleTimeout: maxIdleTimeout}, ProactiveCheckConfig{}) // inactive conn var closed bool @@ -119,7 +120,7 @@ func TestPoolGetWithInactiveConn(t *testing.T) { inactiveNum = 5 ) - p := newPool(minIdle, maxIdle, maxIdleTimeout) + p := newPool(connpool.IdleConfig{MinIdlePerAddress: minIdle, MaxIdlePerAddress: maxIdle, MaxIdleTimeout: maxIdleTimeout}, ProactiveCheckConfig{}) // put active conn activeConn := newLongConnForTest(ctrl, mockAddr0) recycled := p.Put(activeConn) @@ -165,7 +166,7 @@ func TestPoolMaxIdle(t *testing.T) { maxIdleTimeout = time.Millisecond ) - p := newPool(minIdle, maxIdle, maxIdleTimeout) + p := newPool(connpool.IdleConfig{MinIdlePerAddress: minIdle, MaxIdlePerAddress: maxIdle, MaxIdleTimeout: maxIdleTimeout}, ProactiveCheckConfig{}) for i := 0; i < maxIdle+1; i++ { recycled := p.Put(newLongConnForTest(ctrl, mockAddr0)) if i < maxIdle { @@ -187,7 +188,7 @@ func TestPoolMinIdle(t *testing.T) { maxIdleTimeout = time.Millisecond ) - p := newPool(minIdle, maxIdle, maxIdleTimeout) + p := newPool(connpool.IdleConfig{MinIdlePerAddress: minIdle, MaxIdlePerAddress: maxIdle, MaxIdleTimeout: maxIdleTimeout}, ProactiveCheckConfig{}) for i := 0; i < maxIdle+1; i++ { p.Put(newLongConnForTest(ctrl, mockAddr0)) } @@ -208,7 +209,7 @@ func TestPoolClose(t *testing.T) { maxIdleTimeout = time.Millisecond ) - p := newPool(minIdle, maxIdle, maxIdleTimeout) + p := newPool(connpool.IdleConfig{MinIdlePerAddress: minIdle, MaxIdlePerAddress: maxIdle, MaxIdleTimeout: maxIdleTimeout}, ProactiveCheckConfig{}) for i := 0; i < maxIdle+1; i++ { p.Put(newLongConnForTest(ctrl, mockAddr0)) } @@ -229,7 +230,7 @@ func TestPoolDump(t *testing.T) { maxIdleTimeout = time.Millisecond ) - p := newPool(minIdle, maxIdle, maxIdleTimeout) + p := newPool(connpool.IdleConfig{MinIdlePerAddress: minIdle, MaxIdlePerAddress: maxIdle, MaxIdleTimeout: maxIdleTimeout}, ProactiveCheckConfig{}) for i := 0; i < maxIdle+1; i++ { p.Put(newLongConnForTest(ctrl, mockAddr0)) } @@ -914,6 +915,42 @@ func TestLongConnPoolDump(t *testing.T) { test.Assert(t, length == 1) } +func TestLongConnPoolProactiveCheck(t *testing.T) { + idleCfg := connpool.IdleConfig{MaxIdleTimeout: DefaultProactiveConnCheckInterval * 2} + proactiveCheckConfig := ProactiveCheckConfig{ + Enable: true, + Interval: DefaultProactiveConnCheckInterval, + CheckFunc: internalRemote.ConnectionStateCheck, + } + lp := NewLongPoolWithConfig(LongPoolConfig{ + ServiceName: mockDestService, + IdleConfig: idleCfg, + ProactiveCheckConfig: proactiveCheckConfig, + }) + test.Assert(t, lp.sharedTicker.Interval == DefaultProactiveConnCheckInterval) + lp.Close() + p := newPool(idleCfg, proactiveCheckConfig) + test.Assert(t, p.ProactiveCheckConfig.Enable) + test.Assert(t, p.ProactiveCheckConfig.Interval == DefaultProactiveConnCheckInterval) + test.Assert(t, p.ProactiveCheckConfig.CheckFunc != nil) + test.Assert(t, reflect.ValueOf(p.ProactiveCheckConfig.CheckFunc).Pointer() == reflect.ValueOf(internalRemote.ConnectionStateCheck).Pointer()) + // check conn state + err := p.ProactiveCheckConfig.CheckFunc() + test.Assert(t, err == nil) + + // adjust interval + idleCfg.MaxIdleTimeout = DefaultProactiveConnCheckInterval / 2 + lp = NewLongPoolWithConfig(LongPoolConfig{ + ServiceName: mockDestService, + IdleConfig: idleCfg, + ProactiveCheckConfig: proactiveCheckConfig, + }) + test.Assert(t, lp.sharedTicker.Interval == idleCfg.MaxIdleTimeout) + lp.Close() + p = newPool(idleCfg, proactiveCheckConfig) + test.Assert(t, p.ProactiveCheckConfig.Interval == idleCfg.MaxIdleTimeout) +} + func BenchmarkLongPoolGetOne(b *testing.B) { ctrl := gomock.NewController(b) defer ctrl.Finish() diff --git a/pkg/remote/trans/gonet/conn.go b/pkg/remote/trans/gonet/conn.go index cc2583fb1b..b2260229e5 100644 --- a/pkg/remote/trans/gonet/conn.go +++ b/pkg/remote/trans/gonet/conn.go @@ -18,16 +18,19 @@ package gonet import ( "errors" + "fmt" "net" "sync/atomic" + "syscall" "github.com/cloudwego/gopkg/bufiox" ) var ( _ bufioxReadWriter = &cliConn{} + _ syscall.Conn = &cliConn{} _ bufioxReadWriter = &svrConn{} - errConnClosed error = errors.New("connection has been closed") + errConnClosed = errors.New("connection has been closed") ) type bufioxReadWriter interface { @@ -36,12 +39,12 @@ type bufioxReadWriter interface { } // cliConn implements the net.Conn interface. -// FIXME: add proactive state check of long connection type cliConn struct { net.Conn - r *bufiox.DefaultReader - w *bufiox.DefaultWriter - closed uint32 // 1: closed + r *bufiox.DefaultReader + w *bufiox.DefaultWriter + closed uint32 // 1: closed + inactive uint32 // 1: inactive, set by proactive check logic in connection pool } func newCliConn(conn net.Conn) *cliConn { @@ -64,6 +67,27 @@ func (c *cliConn) Read(b []byte) (int, error) { return c.r.Read(b) } +// IsActive is used to check if the connection is active. +func (c *cliConn) IsActive() bool { + return atomic.LoadUint32(&c.inactive) == 0 +} + +func (c *cliConn) SetConnState(inactive bool) { + if inactive { + atomic.StoreUint32(&c.inactive, 1) + } else { + atomic.StoreUint32(&c.inactive, 0) + } +} + +func (c *cliConn) SyscallConn() (syscall.RawConn, error) { + sc, ok := c.Conn.(syscall.Conn) + if !ok { + return nil, fmt.Errorf("conn is not a syscall.Conn, got %T", c.Conn) + } + return sc.SyscallConn() +} + func (c *cliConn) Close() error { if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { c.r.Release(nil) diff --git a/pkg/remote/trans/gonet/conn_test.go b/pkg/remote/trans/gonet/conn_test.go index 1acce0bc46..aa137013a2 100644 --- a/pkg/remote/trans/gonet/conn_test.go +++ b/pkg/remote/trans/gonet/conn_test.go @@ -20,6 +20,8 @@ import ( "sync/atomic" "testing" + internalRemote "github.com/cloudwego/kitex/internal/remote" + "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" ) @@ -49,3 +51,24 @@ func TestSvrConn(t *testing.T) { sc.Close() test.Assert(t, closeNum == 1) } + +func TestCliConn(t *testing.T) { + closeNum := 0 + c := newCliConn(mocks.Conn{CloseFunc: func() (e error) { + closeNum++ + return nil + }}) + test.Assert(t, c.IsActive()) + c.SetConnState(true) + test.Assert(t, !c.IsActive()) + + // check + err := internalRemote.ConnectionStateCheck(c) + test.Assert(t, err != nil) // mocks.Conn does not implement syscall.Conn + + // close + c.Close() + test.Assert(t, atomic.LoadUint32(&c.closed) == 1) + c.Close() + test.Assert(t, closeNum == 1) +} diff --git a/pkg/remote/trans/gonet/dialer.go b/pkg/remote/trans/gonet/dialer.go index fe06331ce7..631736f0fe 100644 --- a/pkg/remote/trans/gonet/dialer.go +++ b/pkg/remote/trans/gonet/dialer.go @@ -21,9 +21,12 @@ import ( "net" "time" + internalRemote "github.com/cloudwego/kitex/internal/remote" "github.com/cloudwego/kitex/pkg/remote" ) +var _ internalRemote.IsGonetDialer = &dialer{} + // NewDialer returns the default go net dialer. func NewDialer() remote.Dialer { return &dialer{} @@ -33,6 +36,10 @@ type dialer struct { net.Dialer } +func (d *dialer) IsGonetDialer() bool { + return true +} + func (d *dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel()