Skip to content

Commit 2b581d6

Browse files
committed
ensure websockets persists until done on drain
add e2e for ws beyond queue drain; move sleep to appropriate loc; add ref to go issue; separate drain test
1 parent 6265a8e commit 2b581d6

File tree

3 files changed

+75
-2
lines changed

3 files changed

+75
-2
lines changed

pkg/queue/sharedmain/handlers.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"net"
2222
"net/http"
23+
"sync/atomic"
2324
"time"
2425

2526
"go.uber.org/zap"
@@ -43,6 +44,7 @@ func mainHandler(
4344
prober func() bool,
4445
stats *netstats.RequestStats,
4546
logger *zap.SugaredLogger,
47+
pendingRequests *atomic.Int32,
4648
) (http.Handler, *pkghandler.Drainer) {
4749
target := net.JoinHostPort("127.0.0.1", env.UserPort)
4850

@@ -86,6 +88,8 @@ func mainHandler(
8688

8789
composedHandler = withFullDuplex(composedHandler, env.EnableHTTPFullDuplex, logger)
8890

91+
composedHandler = withRequestCounter(composedHandler, pendingRequests)
92+
8993
drainer := &pkghandler.Drainer{
9094
QuietPeriod: drainSleepDuration,
9195
// Add Activator probe header to the drainer so it can handle probes directly from activator
@@ -100,6 +104,7 @@ func mainHandler(
100104
// Hence we need to have RequestLogHandler be the first one.
101105
composedHandler = requestLogHandler(logger, composedHandler, env)
102106
}
107+
103108
return composedHandler, drainer
104109
}
105110

@@ -139,3 +144,11 @@ func withFullDuplex(h http.Handler, enableFullDuplex bool, logger *zap.SugaredLo
139144
h.ServeHTTP(w, r)
140145
})
141146
}
147+
148+
// withRequestCounter wraps h so that pendingRequests reflects the number of
// requests currently executing inside h: the counter is incremented before
// h.ServeHTTP is called and decremented once it returns.
//
// If pendingRequests is nil there is nothing to record, so h is returned
// unwrapped instead of panicking on the first request.
func withRequestCounter(h http.Handler, pendingRequests *atomic.Int32) http.Handler {
	if pendingRequests == nil {
		return h
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		pendingRequests.Add(1)
		// Deferred so the count is released even when h panics.
		defer pendingRequests.Add(-1)
		h.ServeHTTP(w, r)
	})
}

pkg/queue/sharedmain/main.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"net/http"
2525
"os"
2626
"strconv"
27+
"sync/atomic"
2728
"time"
2829

2930
"github.com/kelseyhightower/envconfig"
@@ -169,6 +170,8 @@ func Main(opts ...Option) error {
169170
d := Defaults{
170171
Ctx: signals.NewContext(),
171172
}
173+
pendingRequests := atomic.Int32{}
174+
pendingRequests.Store(0)
172175

173176
// Parse the environment.
174177
var env config
@@ -234,7 +237,7 @@ func Main(opts ...Option) error {
234237
// Enable TLS when certificate is mounted.
235238
tlsEnabled := exists(logger, certPath) && exists(logger, keyPath)
236239

237-
mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger)
240+
mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger, &pendingRequests)
238241
adminHandler := adminHandler(d.Ctx, logger, drainer)
239242

240243
// Enable TLS server when activator server certs are mounted.
@@ -303,9 +306,23 @@ func Main(opts ...Option) error {
303306
return err
304307
case <-d.Ctx.Done():
305308
logger.Info("Received TERM signal, attempting to gracefully shutdown servers.")
306-
logger.Infof("Sleeping %v to allow K8s propagation of non-ready state", drainSleepDuration)
307309
drainer.Drain()
308310

311+
// Wait on active requests to complete. This is done explictly
312+
// to avoid closing any connections which have been highjacked,
313+
// as in net/http `.Shutdown` would do so ungracefully.
314+
// See https://github.com/golang/go/issues/17721
315+
ticker := time.NewTicker(1 * time.Second)
316+
defer ticker.Stop()
317+
logger.Infof("Drain: waiting for %d pending requests to complete", pendingRequests.Load())
318+
WaitOnPendingRequests:
319+
for range ticker.C {
320+
if pendingRequests.Load() <= 0 {
321+
logger.Infof("Drain: all pending requests completed")
322+
break WaitOnPendingRequests
323+
}
324+
}
325+
309326
for name, srv := range httpServers {
310327
logger.Info("Shutting down server: ", name)
311328
if err := srv.Shutdown(context.Background()); err != nil {

test/e2e/websocket_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ func TestWebSocketWithTimeout(t *testing.T) {
322322
idleTimeoutSeconds: 10,
323323
delay: "20",
324324
expectError: true,
325+
}, {
326+
name: "websocket does not drop after queue drain is called at 30s",
327+
timeoutSeconds: 60,
328+
delay: "45",
329+
expectError: false,
325330
}}
326331
for _, tc := range testCases {
327332
t.Run(tc.name, func(t *testing.T) {
@@ -349,6 +354,44 @@ func TestWebSocketWithTimeout(t *testing.T) {
349354
}
350355
}
351356

357+
func TestWebSocketDrain(t *testing.T) {
358+
clients := Setup(t)
359+
360+
testCases := []struct {
361+
name string
362+
timeoutSeconds int64
363+
delay string
364+
expectError bool
365+
}{{
366+
name: "websocket does not drop after queue drain is called",
367+
timeoutSeconds: 60,
368+
delay: "45",
369+
expectError: false,
370+
}}
371+
for _, tc := range testCases {
372+
t.Run(tc.name, func(t *testing.T) {
373+
names := test.ResourceNames{
374+
Service: test.ObjectNameForTest(t),
375+
Image: wsServerTestImageName,
376+
}
377+
378+
// Clean up in both abnormal and normal exits.
379+
test.EnsureTearDown(t, clients, &names)
380+
381+
_, err := v1test.CreateServiceReady(t, clients, &names,
382+
rtesting.WithRevisionTimeoutSeconds(tc.timeoutSeconds),
383+
if err != nil {
384+
t.Fatal("Failed to create WebSocket server:", err)
385+
}
386+
// Validate the websocket connection.
387+
err = ValidateWebSocketConnection(t, clients, names, tc.delay)
388+
if (err == nil && tc.expectError) || (err != nil && !tc.expectError) {
389+
t.Error(err)
390+
}
391+
})
392+
}
393+
}
394+
352395
func abs(a int) int {
353396
if a < 0 {
354397
return -a

0 commit comments

Comments
 (0)