Skip to content

Commit

Permalink
Log a 499 if the client abandons the request
Browse files Browse the repository at this point in the history
Previously a client-cancelled request was logged as a proxy error, which
is misleading.

The other situation in which requests may be cancelled is where we're
forced to terminate incomplete requests during draining. In this case,
we'll continue to respond with a 502, but add additional context to the
logs for clarity.
  • Loading branch information
kevinmcconnell committed Nov 14, 2024
1 parent 2f3f10b commit d69a4af
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
35 changes: 30 additions & 5 deletions internal/server/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ import (
"time"
)

const (
HTTP_STATUS_CLIENT_DISCONNECTED = 499
)

var (
ErrorInvalidHostPattern = errors.New("invalid host pattern")
ErrorDraining = errors.New("target is draining")
Expand Down Expand Up @@ -43,7 +47,7 @@ func (ts TargetState) String() string {
}

type inflightRequest struct {
cancel context.CancelFunc
cancel context.CancelCauseFunc
hijacked bool
}

Expand Down Expand Up @@ -124,7 +128,7 @@ func (t *Target) StartRequest(req *http.Request) (*http.Request, error) {
return nil, ErrorDraining
}

ctx, cancel := context.WithCancel(req.Context())
ctx, cancel := context.WithCancelCause(req.Context())
req = req.WithContext(ctx)

inflightRequest := &inflightRequest{cancel: cancel}
Expand Down Expand Up @@ -162,7 +166,7 @@ func (t *Target) Drain(timeout time.Duration) {
// Cancel any hijacked requests immediately, as they may be long-running.
for _, inflight := range toCancel {
if inflight.hijacked {
inflight.cancel()
inflight.cancel(ErrorDraining)
}
}

Expand All @@ -177,7 +181,7 @@ WAIT_FOR_REQUESTS_TO_COMPLETE:

// Cancel any remaining requests.
for _, inflight := range toCancel {
inflight.cancel()
inflight.cancel(ErrorDraining)
}
}

Expand Down Expand Up @@ -299,6 +303,19 @@ func (t *Target) handleProxyError(w http.ResponseWriter, r *http.Request, err er
return
}

if t.isClientCancellation(err) {
// The client has disconnected so will not see the response, but we
// still want to set it for the sake of the logs.
w.WriteHeader(HTTP_STATUS_CLIENT_DISCONNECTED)
return
}

if t.isDraining(err) {
slog.Info("Request cancelled due to draining", "target", t.Target(), "path", r.URL.Path)
SetErrorResponse(w, r, http.StatusBadGateway, nil)
return
}

slog.Error("Error while proxying", "target", t.Target(), "path", r.URL.Path, "error", err)
SetErrorResponse(w, r, http.StatusBadGateway, nil)
}
Expand All @@ -316,6 +333,14 @@ func (t *Target) isGatewayTimeout(err error) bool {
return false
}

func (t *Target) isClientCancellation(err error) bool {
return errors.Is(err, context.Canceled)
}

func (t *Target) isDraining(err error) bool {
return errors.Is(err, ErrorDraining)
}

func (t *Target) updateState(state TargetState) TargetState {
t.inflightLock.Lock()
defer t.inflightLock.Unlock()
Expand All @@ -339,7 +364,7 @@ func (t *Target) endInflightRequest(req *http.Request) {

inflightRequest, ok := t.inflight[req]
if ok {
inflightRequest.cancel()
inflightRequest.cancel(nil)
delete(t.inflight, req)
}
}
Expand Down
20 changes: 19 additions & 1 deletion internal/server/target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,21 @@ func TestTarget_ServeWebSocket(t *testing.T) {
})
}

func TestTarget_CancelledRequestsHaveStatus499(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
w := httptest.NewRecorder()

target := testTarget(t, func(w http.ResponseWriter, r *http.Request) {
cancel()
})

testServeRequestWithTarget(t, target, w, req)

require.Equal(t, 499, w.Result().StatusCode)
require.Empty(t, string(w.Body.String()))
}

func TestTarget_PreserveTargetHeader(t *testing.T) {
var requestTarget string

Expand Down Expand Up @@ -339,7 +354,10 @@ func TestTarget_DrainRequestsThatNeedToBeCancelled(t *testing.T) {
for i := 0; i < n; i++ {
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
go testServeRequestWithTarget(t, target, w, req)
go func() {
testServeRequestWithTarget(t, target, w, req)
assert.Equal(t, http.StatusBadGateway, w.Result().StatusCode)
}()
}

started.Wait()
Expand Down

0 comments on commit d69a4af

Please sign in to comment.