From d69a4af9e5f239b1f1c0ac4c0f2f487b3182673f Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 13 Nov 2024 17:02:35 +0000 Subject: [PATCH] Log a 499 if the client abandons the request Previously a client-cancelled request was logged as a proxy error, which is misleading. The other situation in which requests may be cancelled is where we're forced to terminate incomplete requests during draining. In this case, we'll continue to respond with a 502, but add additional context to the logs for clarity. --- internal/server/target.go | 35 +++++++++++++++++++++++++++++----- internal/server/target_test.go | 20 ++++++++++++++++++- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/internal/server/target.go b/internal/server/target.go index 03114ad..3399ae4 100644 --- a/internal/server/target.go +++ b/internal/server/target.go @@ -15,6 +15,10 @@ import ( "time" ) +const ( + HTTP_STATUS_CLIENT_DISCONNECTED = 499 +) + var ( ErrorInvalidHostPattern = errors.New("invalid host pattern") ErrorDraining = errors.New("target is draining") @@ -43,7 +47,7 @@ func (ts TargetState) String() string { } type inflightRequest struct { - cancel context.CancelFunc + cancel context.CancelCauseFunc hijacked bool } @@ -124,7 +128,7 @@ func (t *Target) StartRequest(req *http.Request) (*http.Request, error) { return nil, ErrorDraining } - ctx, cancel := context.WithCancel(req.Context()) + ctx, cancel := context.WithCancelCause(req.Context()) req = req.WithContext(ctx) inflightRequest := &inflightRequest{cancel: cancel} @@ -162,7 +166,7 @@ func (t *Target) Drain(timeout time.Duration) { // Cancel any hijacked requests immediately, as they may be long-running. for _, inflight := range toCancel { if inflight.hijacked { - inflight.cancel() + inflight.cancel(ErrorDraining) } } @@ -177,7 +181,7 @@ WAIT_FOR_REQUESTS_TO_COMPLETE: // Cancel any remaining requests. for _, inflight := range toCancel { - inflight.cancel() + inflight.cancel(ErrorDraining) } } @@ -299,6 +303,19 @@ func (t *Target) handleProxyError(w http.ResponseWriter, r *http.Request, err er return } + if t.isClientCancellation(err) { + // The client has disconnected so will not see the response, but we + // still want to set it for the sake of the logs. + w.WriteHeader(HTTP_STATUS_CLIENT_DISCONNECTED) + return + } + + if t.isDraining(err) { + slog.Info("Request cancelled due to draining", "target", t.Target(), "path", r.URL.Path) + SetErrorResponse(w, r, http.StatusBadGateway, nil) + return + } + slog.Error("Error while proxying", "target", t.Target(), "path", r.URL.Path, "error", err) SetErrorResponse(w, r, http.StatusBadGateway, nil) } @@ -316,6 +333,14 @@ func (t *Target) isGatewayTimeout(err error) bool { return false } +func (t *Target) isClientCancellation(err error) bool { + return errors.Is(err, context.Canceled) +} + +func (t *Target) isDraining(err error) bool { + return errors.Is(err, ErrorDraining) +} + func (t *Target) updateState(state TargetState) TargetState { t.inflightLock.Lock() defer t.inflightLock.Unlock() @@ -339,7 +364,7 @@ func (t *Target) endInflightRequest(req *http.Request) { inflightRequest, ok := t.inflight[req] if ok { - inflightRequest.cancel() + inflightRequest.cancel(nil) delete(t.inflight, req) } } diff --git a/internal/server/target_test.go b/internal/server/target_test.go index a452f5d..a7e68bd 100644 --- a/internal/server/target_test.go +++ b/internal/server/target_test.go @@ -146,6 +146,21 @@ func TestTarget_ServeWebSocket(t *testing.T) { }) } +func TestTarget_CancelledRequestsHaveStatus499(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + target := testTarget(t, func(w http.ResponseWriter, r *http.Request) { + cancel() + }) + + testServeRequestWithTarget(t, target, w, req) + + require.Equal(t, 499, w.Result().StatusCode) + require.Empty(t, string(w.Body.String())) +} + func TestTarget_PreserveTargetHeader(t *testing.T) { var requestTarget string @@ -339,7 +354,10 @@ func TestTarget_DrainRequestsThatNeedToBeCancelled(t *testing.T) { for i := 0; i < n; i++ { req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() - go testServeRequestWithTarget(t, target, w, req) + go func() { + testServeRequestWithTarget(t, target, w, req) + assert.Equal(t, http.StatusBadGateway, w.Result().StatusCode) + }() } started.Wait()