diff --git a/integration/integration_test.go b/integration/integration_test.go index c02d4c2..5334cec 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -387,6 +387,47 @@ func TestMain(t *testing.T) { assertExitCode(t, 0, res) }) + t.Run("redirects", func(t *testing.T) { + var empty string + var urlStr atomic.Pointer[string] + urlStr.Store(&empty) + + var count atomic.Int64 + server := startServer(func(w http.ResponseWriter, r *http.Request) { + if count.Add(-1) < 0 { + w.WriteHeader(200) + return + } + + w.Header().Set("Location", *urlStr.Load()) + w.WriteHeader(301) + }) + defer server.Close() + urlStr.Store(&server.URL) + + // Success with no redirects. + res := runFetch(t, fetchPath, server.URL, "--redirects", "0") + assertExitCode(t, 0, res) + + // Returns 301 with no redirects. + count.Store(1) + res = runFetch(t, fetchPath, server.URL, "--redirects", "0") + assertExitCode(t, 0, res) + assertBufContains(t, res.stderr, "301 Moved Permanently") + + // Returns 200 with redirects. + count.Store(5) + res = runFetch(t, fetchPath, server.URL) + assertExitCode(t, 0, res) + assertBufContains(t, res.stderr, "200 OK") + + // Returns an error when max redirects exceeded. + count.Store(2) + res = runFetch(t, fetchPath, server.URL, "--redirects", "1") + assertExitCode(t, 1, res) + assertBufContains(t, res.stderr, "exceeded maximum number of redirects") + }) + t.Run("server sent events", func(t *testing.T) { server := startServer(func(w http.ResponseWriter, r *http.Request) { const data = ":comment\n\ndata:{\"key\":\"val\"}\n\nevent:ev1\ndata: this is my data\n\n" diff --git a/internal/cli/app.go b/internal/cli/app.go index f4c5079..ecedce9 100644 --- a/internal/cli/app.go +++ b/internal/cli/app.go @@ -44,6 +44,7 @@ type App struct { Output string Proxy *url.URL QueryParams []core.KeyVal + Redirects *int Silent bool Timeout time.Duration TLS uint16 @@ -568,6 +569,25 @@ func (a *App) CLI() *CLI { return nil }, }, + { + Short: "", + Long: "redirects", + Args: "NUM", + Description: "Maximum number of redirects", + Default: "", + IsSet: func() bool { + return a.Redirects != nil + }, + Fn: func(value string) error { + n, err := strconv.Atoi(value) + if err != nil || n < 0 { + const usage = "must be a positive integer" + return flagValueError("redirects", value, usage) + } + a.Redirects = &n + return nil + }, + }, { Short: "s", Long: "silent", diff --git a/internal/client/client.go b/internal/client/client.go index 09ad6c8..ce185d0 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -4,6 +4,7 @@ import ( "compress/gzip" "context" "crypto/tls" + "fmt" "io" "net" "net/http" @@ -27,6 +28,7 @@ type ClientConfig struct { HTTP core.HTTPVersion Insecure bool Proxy *url.URL + Redirects *int TLS uint16 } @@ -80,11 +82,22 @@ func NewClient(cfg ClientConfig) *Client { } transport.TLSClientConfig.MinVersion = cfg.TLS - return &Client{ - c: &http.Client{ - Transport: transport, - }, + // Optionally set the maximum number of redirects. + client := &http.Client{Transport: transport} + if cfg.Redirects != nil { + redirects := *cfg.Redirects + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if redirects == 0 { + return http.ErrUseLastResponse + } + if len(via) > redirects { + return fmt.Errorf("exceeded maximum number of redirects: %d", redirects) + } + return nil + } } + + return &Client{c: client} } // RequestConfig represents the configuration for creating an HTTP request. diff --git a/internal/fetch/fetch.go b/internal/fetch/fetch.go index c964bfb..7119ede 100644 --- a/internal/fetch/fetch.go +++ b/internal/fetch/fetch.go @@ -45,6 +45,7 @@ type Request struct { NoPager bool Output string PrinterHandle *printer.Handle + Redirects *int TLS uint16 Verbosity core.Verbosity @@ -91,6 +92,7 @@ func fetch(ctx context.Context, r *Request) (int, error) { HTTP: r.HTTP, Insecure: r.Insecure, Proxy: r.Proxy, + Redirects: r.Redirects, TLS: r.TLS, }) req, err := c.NewRequest(ctx, client.RequestConfig{ diff --git a/main.go b/main.go index 5ca3e8b..5f1b21a 100644 --- a/main.go +++ b/main.go @@ -90,6 +90,7 @@ func main() { NoPager: app.NoPager, Output: app.Output, PrinterHandle: printerHandle, + Redirects: app.Redirects, TLS: app.TLS, Verbosity: verbosity,