diff --git a/go/pools/smartconnpool/pool.go b/go/pools/smartconnpool/pool.go
index 7e4b296e634..e227f3a9038 100644
--- a/go/pools/smartconnpool/pool.go
+++ b/go/pools/smartconnpool/pool.go
@@ -128,7 +128,7 @@ type ConnPool[C Connection] struct {
 
 	// workers is a waitgroup for all the currently running worker goroutines
 	workers sync.WaitGroup
-	close   chan struct{}
+	close   atomic.Pointer[chan struct{}]
 
 	capacityMu sync.Mutex
 	config     struct {
@@ -193,14 +193,19 @@ func (pool *ConnPool[C]) runWorker(close <-chan struct{}, interval time.Duration
 }
 
 func (pool *ConnPool[C]) open() {
-	pool.close = make(chan struct{})
+	closeChan := make(chan struct{})
+	if !pool.close.CompareAndSwap(nil, &closeChan) {
+		// already open
+		return
+	}
+
 	pool.capacity.Store(pool.config.maxCapacity)
 	pool.setIdleCount()
 
 	// The expire worker takes care of removing from the waiter list any clients whose
 	// context has been cancelled.
-	pool.runWorker(pool.close, 100*time.Millisecond, func(_ time.Time) bool {
-		maybeStarving := pool.wait.expire(false)
+	pool.runWorker(closeChan, 100*time.Millisecond, func(_ time.Time) bool {
+		maybeStarving := pool.wait.maybeStarvingCount()
 
 		// Do not allow connections to starve; if there's waiters in the queue
 		// and connections in the stack, it means we could be starving them.
@@ -213,7 +218,7 @@ func (pool *ConnPool[C]) open() {
 	idleTimeout := pool.IdleTimeout()
 	if idleTimeout != 0 {
 		// The idle worker takes care of closing connections that have been idle too long
-		pool.runWorker(pool.close, idleTimeout/10, func(now time.Time) bool {
+		pool.runWorker(closeChan, idleTimeout/10, func(now time.Time) bool {
 			pool.closeIdleResources(now)
 			return true
 		})
@@ -224,7 +229,7 @@ func (pool *ConnPool[C]) open() {
 		// The refresh worker periodically checks the refresh callback in this pool
 		// to decide whether all the connections in the pool need to be cycled
 		// (this usually only happens when there's a global DNS change).
-		pool.runWorker(pool.close, refreshInterval, func(_ time.Time) bool {
+		pool.runWorker(closeChan, refreshInterval, func(_ time.Time) bool {
 			refresh, err := pool.config.refresh()
 			if err != nil {
 				log.Error(err)
@@ -241,7 +246,7 @@ func (pool *ConnPool[C]) open() {
 // Open starts the background workers that manage the pool and gets it ready
 // to start serving out connections.
 func (pool *ConnPool[C]) Open(connect Connector[C], refresh RefreshCheck) *ConnPool[C] {
-	if pool.close != nil {
+	if pool.close.Load() != nil {
 		// already open
 		return pool
 	}
@@ -270,7 +275,7 @@ func (pool *ConnPool[C]) CloseWithContext(ctx context.Context) error {
 	pool.capacityMu.Lock()
 	defer pool.capacityMu.Unlock()
 
-	if pool.close == nil || pool.capacity.Load() == 0 {
+	if pool.close.Load() == nil || pool.capacity.Load() == 0 {
 		// already closed
 		return nil
 	}
@@ -280,9 +285,10 @@ func (pool *ConnPool[C]) CloseWithContext(ctx context.Context) error {
 
 	// for the pool
 	err := pool.setCapacity(ctx, 0)
 
-	close(pool.close)
+	closeChan := *pool.close.Swap(nil)
+	close(closeChan)
+
 	pool.workers.Wait()
-	pool.close = nil
 	return err
 }
@@ -312,7 +318,7 @@ func (pool *ConnPool[C]) reopen() {
 
 // IsOpen returns whether the pool is open
 func (pool *ConnPool[C]) IsOpen() bool {
-	return pool.close != nil
+	return pool.close.Load() != nil
 }
 
 // Capacity returns the maximum amount of connections that this pool can maintain open
@@ -430,6 +436,7 @@ func (pool *ConnPool[C]) tryReturnConn(conn *Pooled[C]) bool {
 	if pool.wait.tryReturnConn(conn) {
 		return true
 	}
+
 	if pool.closeOnIdleLimitReached(conn) {
 		return false
 	}
@@ -595,7 +602,13 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) {
 	// to other clients, wait until one of the connections is returned
 	if conn == nil {
 		start := time.Now()
-		conn, err = pool.wait.waitForConn(ctx, nil)
+
+		closeChan := pool.close.Load()
+		if closeChan == nil {
+			return nil, ErrConnPoolClosed
+		}
+
+		conn, err = pool.wait.waitForConn(ctx, nil, *closeChan)
 		if err != nil {
 			return nil, ErrTimeout
 		}
@@ -652,7 +665,13 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (
 	// wait for one of them
 	if conn == nil {
 		start := time.Now()
-		conn, err = pool.wait.waitForConn(ctx, setting)
+
+		closeChan := pool.close.Load()
+		if closeChan == nil {
+			return nil, ErrConnPoolClosed
+		}
+
+		conn, err = pool.wait.waitForConn(ctx, setting, *closeChan)
 		if err != nil {
 			return nil, ErrTimeout
 		}
@@ -729,11 +748,6 @@ func (pool *ConnPool[C]) setCapacity(ctx context.Context, newcap int64) error {
 			"timed out while waiting for connections to be returned to the pool (capacity=%d, active=%d, borrowed=%d)",
 			pool.capacity.Load(), pool.active.Load(), pool.borrowed.Load())
 	}
-	// if we're closing down the pool, make sure there's no clients waiting
-	// for connections because they won't be returned in the future
-	if newcap == 0 {
-		pool.wait.expire(true)
-	}
 
 	// try closing from connections which are currently idle in the stacks
 	conn := pool.getFromSettingsStack(nil)
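The pool.go changes above follow one recurring pattern: the close signal is a channel published through an atomic.Pointer, so Open/CloseWithContext arbitrate ownership with CompareAndSwap/Swap while Get snapshots the channel before it starts waiting. Below is a minimal, self-contained sketch of that pattern; the `service` type and its methods are illustrative stand-ins, not part of the pool's API.

```go
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

var errClosed = errors.New("service is closed")

// service is a hypothetical stand-in for ConnPool: the close channel is
// published through an atomic.Pointer so readers never race against it being
// replaced or cleared.
type service struct {
	closeCh atomic.Pointer[chan struct{}]
}

// open installs a fresh close channel; CompareAndSwap makes a second call a no-op.
func (s *service) open() {
	ch := make(chan struct{})
	s.closeCh.CompareAndSwap(nil, &ch)
}

// close swaps the pointer out first, so new waiters observe "closed" immediately,
// then closes the channel to release anyone already blocked on it.
func (s *service) close() {
	if ch := s.closeCh.Swap(nil); ch != nil {
		close(*ch)
	}
}

// wait snapshots the channel before blocking, mirroring how get() loads the
// pointer before calling waitForConn.
func (s *service) wait() error {
	ch := s.closeCh.Load()
	if ch == nil {
		return errClosed
	}
	<-*ch // in the pool this is just one case of a larger select
	return errClosed
}

func main() {
	var s service
	s.open()
	go s.close()
	fmt.Println(s.wait()) // service is closed
}
```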
diff --git a/go/pools/smartconnpool/pool_test.go b/go/pools/smartconnpool/pool_test.go
index ababeeae0d4..a165bc4db20 100644
--- a/go/pools/smartconnpool/pool_test.go
+++ b/go/pools/smartconnpool/pool_test.go
@@ -1241,3 +1241,84 @@ func TestGetSpike(t *testing.T) {
 		close(errs)
 	}
 }
+
+// TestCloseDuringWaitForConn confirms that we do not hang when the pool is
+// closed while we are waiting for a connection from it.
+func TestCloseDuringWaitForConn(t *testing.T) {
+	ctx := context.Background()
+	goRoutineCnt := 50
+	getTimeout := 2000 * time.Millisecond
+
+	for range 50 {
+		hung := make(chan (struct{}), goRoutineCnt)
+		var state TestState
+		p := NewPool(&Config[*TestConn]{
+			Capacity:     1,
+			MaxIdleCount: 1,
+			IdleTimeout:  time.Second,
+			LogWait:      state.LogWait,
+		}).Open(newConnector(&state), nil)
+
+		closed := atomic.Bool{}
+		wg := sync.WaitGroup{}
+		var count atomic.Int64
+
+		fmt.Println("Starting TestCloseDuringWaitForConn")
+
+		// Spawn multiple goroutines to perform Get and Put operations, but only
+		// allow connections to be checked out until `closed` has been set to true.
+		for range goRoutineCnt {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				for !closed.Load() {
+					timeout := time.After(getTimeout)
+					getCtx, getCancel := context.WithTimeout(ctx, getTimeout/3)
+					defer getCancel()
+					done := make(chan struct{})
+					go func() {
+						defer close(done)
+						r, err := p.Get(getCtx, nil)
+						if err != nil {
+							return
+						}
+						count.Add(1)
+						r.Recycle()
+					}()
+					select {
+					case <-timeout:
+						hung <- struct{}{}
+						return
+					case <-done:
+					}
+				}
+			}()
+		}
+
+		// Let the goroutines get up and running.
+		for count.Load() < 5000 {
+			time.Sleep(1 * time.Millisecond)
+		}
+
+		// Close the pool, which should allow all goroutines to finish.
+		closeCtx, closeCancel := context.WithTimeout(ctx, 1*time.Second)
+		defer closeCancel()
+		err := p.CloseWithContext(closeCtx)
+		closed.Store(true)
+		require.NoError(t, err, "Failed to close pool")
+
+		// Wait for all goroutines to finish.
+		wg.Wait()
+		select {
+		case <-hung:
+			require.FailNow(t, "Race encountered and deadlock detected")
+		default:
+		}
+
+		fmt.Println("Count of connections checked out:", count.Load())
+		// Check that the pool is closed and no connections are available.
+		require.EqualValues(t, 0, p.Capacity())
+		require.EqualValues(t, 0, p.Available())
+		require.EqualValues(t, 0, state.open.Load())
+	}
+}
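Each Get in the test above is wrapped in a watchdog goroutine, so a call that never returns is reported through the `hung` channel instead of deadlocking the test run. A stripped-down version of that pattern, with a hypothetical `slowOperation` standing in for `p.Get`, might look like this:

```go
package main

import (
	"fmt"
	"time"
)

// slowOperation is a hypothetical stand-in for a call such as p.Get that we
// suspect could block forever.
func slowOperation() {
	time.Sleep(50 * time.Millisecond)
}

// runWithWatchdog runs fn in its own goroutine and reports whether it finished
// before the deadline, mirroring the done/timeout select used in the test.
func runWithWatchdog(fn func(), deadline time.Duration) bool {
	done := make(chan struct{})
	go func() {
		defer close(done)
		fn()
	}()

	select {
	case <-done:
		return true
	case <-time.After(deadline):
		// Still running; the test records this on the `hung` channel and later
		// turns it into a failure.
		return false
	}
}

func main() {
	fmt.Println(runWithWatchdog(slowOperation, time.Second))         // true
	fmt.Println(runWithWatchdog(slowOperation, 10*time.Millisecond)) // false
}
```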
diff --git a/go/pools/smartconnpool/sema.s b/go/pools/smartconnpool/sema.s
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/go/pools/smartconnpool/sema_norace.go b/go/pools/smartconnpool/sema_norace.go
deleted file mode 100644
index 63afe8082c1..00000000000
--- a/go/pools/smartconnpool/sema_norace.go
+++ /dev/null
@@ -1,40 +0,0 @@
-//go:build !race
-
-/*
-Copyright 2023 The Vitess Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package smartconnpool
-
-import _ "unsafe"
-
-//go:linkname sync_runtime_Semacquire sync.runtime_Semacquire
-func sync_runtime_Semacquire(addr *uint32)
-
-//go:linkname sync_runtime_Semrelease sync.runtime_Semrelease
-func sync_runtime_Semrelease(addr *uint32, handoff bool, skipframes int)
-
-// semaphore is a single-use synchronization primitive that allows a Goroutine
-// to wait until signaled. We use the Go runtime's internal implementation.
-type semaphore struct {
-	f uint32
-}
-
-func (s *semaphore) wait() {
-	sync_runtime_Semacquire(&s.f)
-}
-func (s *semaphore) notify(handoff bool) {
-	sync_runtime_Semrelease(&s.f, handoff, 0)
-}
diff --git a/go/pools/smartconnpool/sema_race.go b/go/pools/smartconnpool/sema_race.go
deleted file mode 100644
index a31cfaa85c5..00000000000
--- a/go/pools/smartconnpool/sema_race.go
+++ /dev/null
@@ -1,42 +0,0 @@
-//go:build race
-
-/*
-Copyright 2023 The Vitess Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package smartconnpool
-
-import (
-	"sync/atomic"
-	"time"
-)
-
-// semaphore is a slow implementation of a single-use synchronization primitive.
-// We use this inefficient implementation when running under the race detector
-// because the detector doesn't understand the synchronization performed by the
-// runtime's semaphore.
-type semaphore struct {
-	b atomic.Bool
-}
-
-func (s *semaphore) wait() {
-	for !s.b.CompareAndSwap(true, false) {
-		time.Sleep(time.Millisecond)
-	}
-}
-
-func (s *semaphore) notify(_ bool) {
-	s.b.Store(true)
-}
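Both deleted files exist only to park a waiting goroutine and wake it when a connection is handed over. With the waitlist changes below, that hand-off happens over a per-waiter channel, which the race detector and the scheduler understand natively. A minimal sketch of the idea, using names that are illustrative rather than the pool's actual types:

```go
package main

import (
	"fmt"
	"sync"
)

// conn is a hypothetical stand-in for *Pooled[C].
type conn struct{ id int }

// waiter mirrors the new waiter struct: the connection is delivered over an
// unbuffered channel rather than stored in a field guarded by a semaphore.
type waiter struct {
	ch chan *conn
}

// handOff is what the returning side does: because the channel is unbuffered,
// the send blocks until the waiter actually receives, so the connection can
// never be dropped on the floor.
func handOff(w *waiter, c *conn) {
	w.ch <- c
}

func main() {
	w := &waiter{ch: make(chan *conn)}

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// waiting side: blocks until the hand-off arrives
		c := <-w.ch
		fmt.Println("received connection", c.id)
	}()

	handOff(w, &conn{id: 1})
	wg.Wait()
}
```

Because the channel is unbuffered, the hand-off is also the synchronization point. That is why, in waitForConn below, a waiter that could not remove itself from the waitlist must still receive from its channel: another goroutine is already committed to sending on it.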
diff --git a/go/pools/smartconnpool/waitlist.go b/go/pools/smartconnpool/waitlist.go
index ef1eb1fe997..93f391b6d53 100644
--- a/go/pools/smartconnpool/waitlist.go
+++ b/go/pools/smartconnpool/waitlist.go
@@ -18,6 +18,7 @@ package smartconnpool
 
 import (
 	"context"
+	"runtime"
 	"sync"
 
 	"vitess.io/vitess/go/list"
@@ -28,13 +29,8 @@ type waiter[C Connection] struct {
 	// setting is the connection Setting that we'd like, or nil if we'd like a
 	// a connection with no Setting applied
 	setting *Setting
-	// conn will be set by another client to hand over the connection to use
-	conn *Pooled[C]
-	// ctx is the context of the waiting client to check for expiration
-	ctx context.Context
-	// sema is a synchronization primitive that allows us to block until our request
-	// has been fulfilled
-	sema semaphore
+	// conn is a channel that will receive the connection when it's ready
+	conn chan *Pooled[C]
 	// age is the amount of cycles this client has been on the waitlist
 	age uint32
 }
@@ -50,61 +46,86 @@ type waitlist[C Connection] struct {
 
 // The returned connection may _not_ have the requested Setting. This function can
-// also return a `nil` connection even if our context has expired, if the pool has
-// forced an expiration of all waiters in the waitlist.
-func (wl *waitlist[C]) waitForConn(ctx context.Context, setting *Setting) (*Pooled[C], error) {
+// also return a connection even if our context has expired, if another client
+// is already in the process of handing one over to us.
+func (wl *waitlist[C]) waitForConn(ctx context.Context, setting *Setting, closeChan <-chan struct{}) (*Pooled[C], error) {
 	elem := wl.nodes.Get().(*list.Element[waiter[C]])
-	elem.Value = waiter[C]{setting: setting, conn: nil, ctx: ctx}
+	defer wl.nodes.Put(elem)
+
+	elem.Value = waiter[C]{conn: elem.Value.conn, setting: setting}
 
 	wl.mu.Lock()
 	// add ourselves as a waiter at the end of the waitlist
 	wl.list.PushBackValue(elem)
 	wl.mu.Unlock()
 
-	// block on our waiter's semaphore until somebody can hand over a connection to us
-	elem.Value.sema.wait()
+	select {
+	case <-closeChan:
+		// Pool was closed while we were waiting.
+		removed := false
+
+		wl.mu.Lock()
+		// Try to find and remove ourselves from the list.
+		for e := wl.list.Front(); e != nil; e = e.Next() {
+			if e == elem {
+				wl.list.Remove(elem)
+				removed = true
+				break
+			}
+		}
+		wl.mu.Unlock()
+
+		if removed {
+			return nil, ErrConnPoolClosed
+		}
+
+		// if we weren't able to remove ourselves from the waitlist, it means
+		// another goroutine is trying to hand us a connection
+		return <-elem.Value.conn, nil
+
+	case <-ctx.Done():
+		// Context expired. We need to try to remove ourselves from the waitlist to
+		// prevent another goroutine from trying to hand us a connection later on.
+		removed := false
+
+		wl.mu.Lock()
+		// Try to find and remove ourselves from the list.
+		for e := wl.list.Front(); e != nil; e = e.Next() {
+			if e == elem {
+				wl.list.Remove(elem)
+				removed = true
+				break
+			}
+		}
+		wl.mu.Unlock()
+
+		if removed {
+			return nil, context.Cause(ctx)
+		}
-	// we're awake -- the conn in our waiter contains the connection that was handed
-	// over to us, or nothing if we've been waken up forcefully. save the conn before
-	// we return our waiter to the pool of waiters for reuse.
-	conn := elem.Value.conn
-	wl.nodes.Put(elem)
+
+		// if we weren't able to remove ourselves from the waitlist, it means
+		// another goroutine is trying to hand us a connection
+		return <-elem.Value.conn, nil
 
-	if conn != nil {
+	case conn := <-elem.Value.conn:
 		return conn, nil
 	}
-	return nil, ctx.Err()
 }
 
-// expire removes and wakes any expired waiter in the waitlist.
-// if force is true, it'll wake and remove all the waiters.
-func (wl *waitlist[C]) expire(force bool) (maybeStarving int) {
+func (wl *waitlist[C]) maybeStarvingCount() (maybeStarving int) {
 	if wl.list.Len() == 0 {
 		return
 	}
-	var expired []*list.Element[waiter[C]]
-
 	wl.mu.Lock()
-	// iterate the waitlist looking for waiters with an expired Context,
-	// or remove everything if force is true
+	defer wl.mu.Unlock()
+
+	// iterate the waitlist and count the waiters that may be starving, i.e.
+	// those that have not yet been on the list for a full worker cycle
 	for e := wl.list.Front(); e != nil; e = e.Next() {
-		if force || e.Value.ctx.Err() != nil {
-			expired = append(expired, e)
-			continue
-		}
 		if e.Value.age == 0 {
 			maybeStarving++
 		}
 	}
-	// remove the expired waiters from the waitlist after traversing it
-	for _, e := range expired {
-		wl.list.Remove(e)
-	}
-	wl.mu.Unlock()
-
-	// once all the expired waiters have been removed from the waitlist, wake them up one by one
-	for _, e := range expired {
-		e.Value.sema.notify(false)
-	}
 	return
 }
@@ -154,16 +175,19 @@ func (wl *waitlist[D]) tryReturnConnSlow(conn *Pooled[D]) bool {
 	}
 
 	// if we have a target to return the connection to, simply write the connection
-	// into the waiter and signal their semaphore. they'll wake up to pick up the
-	// connection.
-	target.Value.conn = conn
-	target.Value.sema.notify(true)
+	// into the waiter's channel.
+	target.Value.conn <- conn
+	// Allow the goroutine waiting on the channel to start running _now_.
+	runtime.Gosched()
+
 	return true
 }
 
 func (wl *waitlist[C]) init() {
 	wl.nodes.New = func() any {
-		return &list.Element[waiter[C]]{}
+		return &list.Element[waiter[C]]{
+			Value: waiter[C]{conn: make(chan *Pooled[C])},
+		}
 	}
 	wl.list.Init()
 }
diff --git a/go/pools/smartconnpool/waitlist_test.go b/go/pools/smartconnpool/waitlist_test.go
index 1486aa989b6..aa3078c5457 100644
--- a/go/pools/smartconnpool/waitlist_test.go
+++ b/go/pools/smartconnpool/waitlist_test.go
@@ -26,31 +26,33 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
-func TestWaitlistExpireWithMultipleWaiters(t *testing.T) {
+func TestWaitlistPoolCloseWithMultipleWaiters(t *testing.T) {
 	wait := waitlist[*TestConn]{}
 	wait.init()
 
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
 	defer cancel()
 
+	poolClose := make(chan struct{})
+
 	waiterCount := 2
 	expireCount := atomic.Int32{}
 
 	for i := 0; i < waiterCount; i++ {
 		go func() {
-			_, err := wait.waitForConn(ctx, nil)
+			_, err := wait.waitForConn(ctx, nil, poolClose)
+
 			if err != nil {
 				expireCount.Add(1)
 			}
 		}()
 	}
 
+	close(poolClose)
+
 	// Wait for the context to expire
 	<-ctx.Done()
 
-	// Expire the waiters
-	wait.expire(false)
-
 	// Wait for the notified goroutines to finish
 	timeout := time.After(1 * time.Second)
 	ticker := time.NewTicker(10 * time.Millisecond)