Merged
50 changes: 32 additions & 18 deletions go/pools/smartconnpool/pool.go
@@ -128,7 +128,7 @@ type ConnPool[C Connection] struct {
 
 	// workers is a waitgroup for all the currently running worker goroutines
 	workers    sync.WaitGroup
-	close      chan struct{}
+	close      atomic.Pointer[chan struct{}]
 	capacityMu sync.Mutex
 
 	config struct {
@@ -193,14 +193,19 @@ func (pool *ConnPool[C]) runWorker(close <-chan struct{}, interval time.Duration
 }
 
 func (pool *ConnPool[C]) open() {
-	pool.close = make(chan struct{})
+	closeChan := make(chan struct{})
+	if !pool.close.CompareAndSwap(nil, &closeChan) {
+		// already open
+		return
+	}
+
 	pool.capacity.Store(pool.config.maxCapacity)
 	pool.setIdleCount()
 
 	// The expire worker takes care of removing from the waiter list any clients whose
 	// context has been cancelled.
-	pool.runWorker(pool.close, 100*time.Millisecond, func(_ time.Time) bool {
-		maybeStarving := pool.wait.expire(false)
+	pool.runWorker(closeChan, 100*time.Millisecond, func(_ time.Time) bool {
+		maybeStarving := pool.wait.maybeStarvingCount()
 
 		// Do not allow connections to starve; if there's waiters in the queue
 		// and connections in the stack, it means we could be starving them.
@@ -213,7 +218,7 @@ func (pool *ConnPool[C]) open() {
 	idleTimeout := pool.IdleTimeout()
 	if idleTimeout != 0 {
 		// The idle worker takes care of closing connections that have been idle too long
-		pool.runWorker(pool.close, idleTimeout/10, func(now time.Time) bool {
+		pool.runWorker(closeChan, idleTimeout/10, func(now time.Time) bool {
 			pool.closeIdleResources(now)
 			return true
 		})
@@ -224,7 +229,7 @@ func (pool *ConnPool[C]) open() {
 		// The refresh worker periodically checks the refresh callback in this pool
 		// to decide whether all the connections in the pool need to be cycled
 		// (this usually only happens when there's a global DNS change).
-		pool.runWorker(pool.close, refreshInterval, func(_ time.Time) bool {
+		pool.runWorker(closeChan, refreshInterval, func(_ time.Time) bool {
 			refresh, err := pool.config.refresh()
 			if err != nil {
 				log.Error(err)
@@ -241,7 +246,7 @@ func (pool *ConnPool[C]) open() {
 // Open starts the background workers that manage the pool and gets it ready
 // to start serving out connections.
 func (pool *ConnPool[C]) Open(connect Connector[C], refresh RefreshCheck) *ConnPool[C] {
-	if pool.close != nil {
+	if pool.close.Load() != nil {
 		// already open
 		return pool
 	}
@@ -270,7 +275,7 @@ func (pool *ConnPool[C]) CloseWithContext(ctx context.Context) error {
 	pool.capacityMu.Lock()
 	defer pool.capacityMu.Unlock()
 
-	if pool.close == nil || pool.capacity.Load() == 0 {
+	if pool.close.Load() == nil || pool.capacity.Load() == 0 {
 		// already closed
 		return nil
 	}
@@ -280,9 +285,10 @@ func (pool *ConnPool[C]) CloseWithContext(ctx context.Context) error {
 	// for the pool
 	err := pool.setCapacity(ctx, 0)
 
-	close(pool.close)
+	closeChan := *pool.close.Swap(nil)
+	close(closeChan)
+
 	pool.workers.Wait()
-	pool.close = nil
 	return err
 }
 
@@ -312,7 +318,7 @@ func (pool *ConnPool[C]) reopen() {
 
 // IsOpen returns whether the pool is open
 func (pool *ConnPool[C]) IsOpen() bool {
-	return pool.close != nil
+	return pool.close.Load() != nil
 }
 
 // Capacity returns the maximum amount of connections that this pool can maintain open
@@ -430,6 +436,7 @@ func (pool *ConnPool[C]) tryReturnConn(conn *Pooled[C]) bool {
 	if pool.wait.tryReturnConn(conn) {
 		return true
 	}
+
 	if pool.closeOnIdleLimitReached(conn) {
 		return false
 	}
@@ -595,7 +602,13 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) {
 	// to other clients, wait until one of the connections is returned
 	if conn == nil {
 		start := time.Now()
-		conn, err = pool.wait.waitForConn(ctx, nil)
+
+		closeChan := pool.close.Load()
+		if closeChan == nil {
+			return nil, ErrConnPoolClosed
+		}
+
+		conn, err = pool.wait.waitForConn(ctx, nil, *closeChan)
 		if err != nil {
 			return nil, ErrTimeout
 		}
@@ -652,7 +665,13 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (
 	// wait for one of them
 	if conn == nil {
 		start := time.Now()
-		conn, err = pool.wait.waitForConn(ctx, setting)
+
+		closeChan := pool.close.Load()
+		if closeChan == nil {
+			return nil, ErrConnPoolClosed
+		}
+
+		conn, err = pool.wait.waitForConn(ctx, setting, *closeChan)
 		if err != nil {
 			return nil, ErrTimeout
 		}
@@ -729,11 +748,6 @@ func (pool *ConnPool[C]) setCapacity(ctx context.Context, newcap int64) error {
 			"timed out while waiting for connections to be returned to the pool (capacity=%d, active=%d, borrowed=%d)",
 			pool.capacity.Load(), pool.active.Load(), pool.borrowed.Load())
 	}
-	// if we're closing down the pool, make sure there's no clients waiting
-	// for connections because they won't be returned in the future
-	if newcap == 0 {
-		pool.wait.expire(true)
-	}
 
 	// try closing from connections which are currently idle in the stacks
 	conn := pool.getFromSettingsStack(nil)
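The pool.go changes above all follow from one structural move: the close signal becomes an atomic.Pointer[chan struct{}] that is installed exactly once with CompareAndSwap in open() and detached with Swap(nil) in CloseWithContext(), so readers such as IsOpen() and the two get paths can observe the channel without racing on a bare field. Below is a minimal, self-contained sketch of that lifecycle pattern in isolation; the lifecycle type and its method names are illustrative, not part of the smartconnpool API.

package main

import (
	"fmt"
	"sync/atomic"
)

// lifecycle sketches the open/close pattern this PR adopts for ConnPool.close:
// the close channel is published through an atomic pointer, so concurrent
// opens, closes, and readers never race on a plain struct field.
type lifecycle struct {
	closeCh atomic.Pointer[chan struct{}]
}

// Open installs a fresh close channel; only the first caller wins.
func (l *lifecycle) Open() bool {
	ch := make(chan struct{})
	return l.closeCh.CompareAndSwap(nil, &ch)
}

// Close detaches the channel before closing it, so it is closed exactly once
// even if Close is called concurrently.
func (l *lifecycle) Close() bool {
	ch := l.closeCh.Swap(nil)
	if ch == nil {
		return false // never opened, or already closed
	}
	close(*ch)
	return true
}

// IsOpen mirrors ConnPool.IsOpen: "open" means a close channel is installed.
func (l *lifecycle) IsOpen() bool {
	return l.closeCh.Load() != nil
}

func main() {
	var l lifecycle
	fmt.Println(l.Open(), l.Open())   // true false: the second Open is a no-op
	fmt.Println(l.IsOpen())           // true
	fmt.Println(l.Close(), l.Close()) // true false: the channel is closed exactly once
}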
81 changes: 81 additions & 0 deletions go/pools/smartconnpool/pool_test.go
@@ -1241,3 +1241,84 @@ func TestGetSpike(t *testing.T) {
 		close(errs)
 	}
 }
+
+// TestCloseDuringWaitForConn confirms that we do not get hung when the pool gets
+// closed while we are waiting for a connection from it.
+func TestCloseDuringWaitForConn(t *testing.T) {
+	ctx := context.Background()
+	goRoutineCnt := 50
+	getTimeout := 2000 * time.Millisecond
+
+	for range 50 {
+		hung := make(chan (struct{}), goRoutineCnt)
+		var state TestState
+		p := NewPool(&Config[*TestConn]{
+			Capacity:     1,
+			MaxIdleCount: 1,
+			IdleTimeout:  time.Second,
+			LogWait:      state.LogWait,
+		}).Open(newConnector(&state), nil)
+
+		closed := atomic.Bool{}
+		wg := sync.WaitGroup{}
+		var count atomic.Int64
+
+		fmt.Println("Starting TestCloseDuringWaitForConn")
+
+		// Spawn multiple goroutines to perform Get and Put operations, but only
+		// allow connections to be checked out until `closed` has been set to true.
+		for range goRoutineCnt {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				for !closed.Load() {
+					timeout := time.After(getTimeout)
+					getCtx, getCancel := context.WithTimeout(ctx, getTimeout/3)
+					defer getCancel()
+					done := make(chan struct{})
+					go func() {
+						defer close(done)
+						r, err := p.Get(getCtx, nil)
+						if err != nil {
+							return
+						}
+						count.Add(1)
+						r.Recycle()
+					}()
+					select {
+					case <-timeout:
+						hung <- struct{}{}
+						return
+					case <-done:
+					}
+				}
+			}()
+		}
+
+		// Let the go-routines get up and running.
+		for count.Load() < 5000 {
+			time.Sleep(1 * time.Millisecond)
+		}
+
+		// Close the pool, which should allow all goroutines to finish.
+		closeCtx, closeCancel := context.WithTimeout(ctx, 1*time.Second)
+		defer closeCancel()
+		err := p.CloseWithContext(closeCtx)
+		closed.Store(true)
+		require.NoError(t, err, "Failed to close pool")
+
+		// Wait for all goroutines to finish.
+		wg.Wait()
+		select {
+		case <-hung:
+			require.FailNow(t, "Race encountered and deadlock detected")
+		default:
+		}
+
+		fmt.Println("Count of connections checked out:", count.Load())
+		// Check that the pool is closed and no connections are available.
+		require.EqualValues(t, 0, p.Capacity())
+		require.EqualValues(t, 0, p.Available())
+		require.EqualValues(t, 0, state.open.Load())
+	}
+}
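The waitlist side of this change is not shown in this excerpt, but the new third argument in pool.wait.waitForConn(ctx, setting, *closeChan) implies the waiter now also watches the pool's close channel, which is what lets the test above return instead of hanging. Below is a hedged sketch of that shape only; the function name, signature, and error value are illustrative assumptions, not the actual smartconnpool waitlist code.

package main

import (
	"context"
	"errors"
	"fmt"
)

// errPoolClosed stands in for the pool's "closed" error in this sketch.
var errPoolClosed = errors.New("connection pool is closed")

// waitForConnSketch shows the assumed shape of the waitlist change: a blocked
// getter is released by a connection handoff, by its own context, or by the
// pool's close channel, so CloseWithContext cannot leave it hanging forever.
func waitForConnSketch[C any](ctx context.Context, handoff <-chan C, closed <-chan struct{}) (C, error) {
	var zero C
	select {
	case conn := <-handoff:
		return conn, nil
	case <-ctx.Done():
		return zero, ctx.Err()
	case <-closed:
		return zero, errPoolClosed
	}
}

func main() {
	// Simulate a pool that closes while a getter is still blocked waiting.
	handoff := make(chan int)
	closed := make(chan struct{})
	close(closed)

	_, err := waitForConnSketch[int](context.Background(), handoff, closed)
	fmt.Println(err) // connection pool is closed
}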
Empty file removed go/pools/smartconnpool/sema.s
40 changes: 0 additions & 40 deletions go/pools/smartconnpool/sema_norace.go

This file was deleted.

42 changes: 0 additions & 42 deletions go/pools/smartconnpool/sema_race.go

This file was deleted.
