diff --git a/integration/integration_test.go b/integration/integration_test.go index 7ab6b89..c02d4c2 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -342,6 +342,21 @@ func TestMain(t *testing.T) { assertBufEquals(t, res.stdout, data) }) + t.Run("timeout", func(t *testing.T) { + server := startServer(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + return + case <-time.After(time.Second): + } + }) + defer server.Close() + + res := runFetch(t, fetchPath, server.URL, "-t", "0.0000001") + assertExitCode(t, 1, res) + assertBufContains(t, res.stderr, "request timed out after 100ns") + }) + t.Run("ignore status", func(t *testing.T) { var statusCode atomic.Int64 statusCode.Store(200) diff --git a/internal/client/client.go b/internal/client/client.go index c6704b8..7b4eaed 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -33,7 +33,6 @@ type ClientConfig struct { HTTP HTTPVersion Insecure bool Proxy *url.URL - Timeout time.Duration TLS uint16 } @@ -84,7 +83,6 @@ func NewClient(cfg ClientConfig) *Client { return &Client{ c: &http.Client{ - Timeout: cfg.Timeout, Transport: transport, }, } diff --git a/internal/fetch/fetch.go b/internal/fetch/fetch.go index d1eaa63..3ebcd08 100644 --- a/internal/fetch/fetch.go +++ b/internal/fetch/fetch.go @@ -97,20 +97,11 @@ func Fetch(ctx context.Context, r *Request) int { } func fetch(ctx context.Context, r *Request) (int, error) { - errPrinter := r.PrinterHandle.Stderr() - outPrinter := r.PrinterHandle.Stdout() - - if r.URL.Scheme == "" { - // Use HTTPS if no scheme is defined. - r.URL.Scheme = "https" - } - c := client.NewClient(client.ClientConfig{ DNSServer: r.DNSServer, HTTP: r.HTTP, Insecure: r.Insecure, Proxy: r.Proxy, - Timeout: r.Timeout, TLS: r.TLS, }) req, err := c.NewRequest(ctx, client.RequestConfig{ @@ -159,6 +150,7 @@ func fetch(ctx context.Context, r *Request) (int, error) { } if r.Verbosity >= VExtraVerbose || r.DryRun { + errPrinter := r.PrinterHandle.Stderr() printRequestMetadata(errPrinter, req) if r.DryRun { @@ -178,6 +170,18 @@ func fetch(ctx context.Context, r *Request) (int, error) { errPrinter.Flush() } + if r.Timeout > 0 { + var cancel context.CancelFunc + cause := vars.ErrRequestTimedOut{Timeout: r.Timeout} + ctx, cancel = context.WithTimeoutCause(req.Context(), r.Timeout, cause) + defer cancel() + req = req.WithContext(ctx) + } + + return makeRequest(r, c, req) +} + +func makeRequest(r *Request, c *client.Client, req *http.Request) (int, error) { resp, err := c.Do(req) if err != nil { return 0, err @@ -190,17 +194,19 @@ func fetch(ctx context.Context, r *Request) (int, error) { } if r.Verbosity >= VNormal { - printResponseMetadata(errPrinter, r.Verbosity, resp) - errPrinter.Flush() + p := r.PrinterHandle.Stderr() + printResponseMetadata(p, r.Verbosity, resp) + p.Flush() } - body, err := formatResponse(r, resp, outPrinter) + body, err := formatResponse(r, resp, r.PrinterHandle.Stdout()) if err != nil { return 0, err } if body != nil { - err = streamToStdout(body, errPrinter, r.Output == "-", r.NoPager) + p := r.PrinterHandle.Stderr() + err = streamToStdout(body, p, r.Output == "-", r.NoPager) if err != nil { return 0, err } diff --git a/internal/update/update.go b/internal/update/update.go index ffbf63a..4b68e6e 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -40,8 +40,14 @@ func Update(ctx context.Context, p *printer.Printer, timeout time.Duration, sile } func update(ctx context.Context, p *printer.Printer, timeout time.Duration, silent bool) error { - cfg := client.ClientConfig{Timeout: timeout} - c := client.NewClient(cfg) + c := client.NewClient(client.ClientConfig{}) + + if timeout > 0 { + var cancel context.CancelFunc + cause := vars.ErrRequestTimedOut{Timeout: timeout} + ctx, cancel = context.WithTimeoutCause(ctx, timeout, cause) + defer cancel() + } writeInfo(p, silent, "fetching latest release tag") latest, err := getLatestRelease(ctx, c) diff --git a/internal/vars/vars.go b/internal/vars/vars.go index 7fe06c9..cd78bc0 100644 --- a/internal/vars/vars.go +++ b/internal/vars/vars.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "runtime/debug" + "time" "golang.org/x/term" ) @@ -36,6 +37,14 @@ type KeyVal struct { Key, Val string } +type ErrRequestTimedOut struct { + Timeout time.Duration +} + +func (err ErrRequestTimedOut) Error() string { + return fmt.Sprintf("request timed out after %s", err.Timeout) +} + type SignalError string func (err SignalError) Error() string {