Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,26 @@ This builds a static binary that can work inside containers.
- Type: `bool`
- Default: `false`

#### Example
#### Examples

Download and extract an archive:

pget https://storage.googleapis.com/replicant-misc/sd15.tar ./sd15 -x

This command will download Stable Diffusion 1.5 weights to the path ./sd15 with high concurrency. After the file is downloaded, it will be automatically extracted.

Download with authentication headers:

pget -H "Authorization: Bearer token123" https://api.example.com/file.tar ./file.tar

Download with multiple custom headers:

pget -H "Authorization: Bearer token123" -H "X-Custom-Header: value" https://api.example.com/file.tar ./file.tar

Use environment variable for headers:

PGET_HEADERS='{"Authorization":"Bearer token123"}' pget https://api.example.com/file.tar ./file.tar

### Multi-File Mode
pget multifile <manifest-file>

Expand Down Expand Up @@ -112,6 +126,11 @@ https://example.com/music.mp3 /local/path/to/music.mp3
- Force download, overwriting existing file
- Type: `bool`
- Default: `false`
- `-H`, `--header`
- HTTP headers to include in requests (format: 'Key: Value'), can be specified multiple times
- Type: `string slice`
- Example: `-H "Authorization: Bearer token123" -H "X-Custom-Header: value"`
- Environment variable: `PGET_HEADERS` (JSON map format: `{"Key":"Value"}`)
- `--log-level`
- Log level (debug, info, warn, error)
- Type: `string`
Expand Down
19 changes: 19 additions & 0 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,24 @@ func rootPersistentPreRunEFunc(cmd *cobra.Command, args []string) error {
viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor)
}

// Process headers from CLI flags
headerSlice := viper.GetStringSlice(config.OptHeader)
if len(headerSlice) > 0 {
headerMap, err := config.HeadersToMap(headerSlice)
if err != nil {
return fmt.Errorf("error parsing headers: %w", err)
}
// Merge with any existing headers from environment variable
existingHeaders := viper.GetStringMapString(config.OptHeaders)
if existingHeaders == nil {
existingHeaders = make(map[string]string)
}
for k, v := range headerMap {
existingHeaders[k] = v
}
viper.Set(config.OptHeaders, existingHeaders)
}

return nil
}

Expand All @@ -179,6 +197,7 @@ func persistentFlags(cmd *cobra.Command) error {
cmd.PersistentFlags().Int(config.OptMaxConnPerHost, 40, "Maximum number of (global) concurrent connections per host")
cmd.PersistentFlags().StringP(config.OptOutputConsumer, "o", "file", "Output Consumer (file, tar, null)")
cmd.PersistentFlags().String(config.OptPIDFile, defaultPidFilePath(), "PID file path")
cmd.PersistentFlags().StringSliceP(config.OptHeader, "H", []string{}, "HTTP headers to include in requests (format: 'Key: Value')")

if err := hideAndDeprecateFlags(cmd); err != nil {
return err
Expand Down
4 changes: 3 additions & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ type PGetHTTPClient struct {
}

func (c *PGetHTTPClient) Do(req *http.Request) (*http.Response, error) {
req.Header.Set("User-Agent", fmt.Sprintf("pget/%s", version.GetVersion()))
// Set custom headers first
for k, v := range c.headers {
req.Header.Set(k, v)
}
// Set User-Agent last to ensure it's always the pget user agent
req.Header.Set("User-Agent", fmt.Sprintf("pget/%s", version.GetVersion()))
return c.Client.Do(req)
}

Expand Down
45 changes: 45 additions & 0 deletions pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/replicate/pget/pkg/client"
"github.com/replicate/pget/pkg/config"
Expand Down Expand Up @@ -160,3 +163,45 @@ func TestRetryPolicy(t *testing.T) {
})
}
}

func TestPGetHTTPClient_Headers(t *testing.T) {
// Create a test server that echoes back the headers
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Write back the custom headers as response headers for verification
for key, values := range r.Header {
for _, value := range values {
w.Header().Add("Echo-"+key, value)
}
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

// Set up viper with custom headers
viper.Set(config.OptHeaders, map[string]string{
"Authorization": "Bearer test-token",
"X-Custom-Header": "custom-value",
})
defer viper.Reset()

// Create client
httpClient := client.NewHTTPClient(client.Options{
MaxRetries: 0,
})

// Make a request
req, err := http.NewRequest("GET", server.URL, nil)
require.NoError(t, err)

resp, err := httpClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify that our custom headers were sent
assert.Equal(t, "Bearer test-token", resp.Header.Get("Echo-Authorization"))
assert.Equal(t, "custom-value", resp.Header.Get("Echo-X-Custom-Header"))

// Verify that User-Agent is set and contains "pget"
userAgent := resp.Header.Get("Echo-User-Agent")
assert.Contains(t, userAgent, "pget/")
}
34 changes: 34 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,40 @@ func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, error)
return resolveOverrideMap, nil
}

// HeadersToMap converts a slice of header strings in the format "Key: Value" to a map[string]string.
// It merges with any existing headers from the PGET_HEADERS environment variable.
func HeadersToMap(headerSlice []string) (map[string]string, error) {
logger := logging.GetLogger()
headerMap := make(map[string]string)

if len(headerSlice) == 0 {
return nil, nil
}

for _, header := range headerSlice {
// Split on the first colon to separate key and value
parts := strings.SplitN(header, ":", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid header format, expected 'Key: Value', got: %s", header)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])

if key == "" {
return nil, fmt.Errorf("header key cannot be empty in: %s", header)
}

headerMap[key] = value
}

if logger.GetLevel() == zerolog.DebugLevel {
for key, value := range headerMap {
logger.Debug().Str("header", key).Str("value", value).Msg("Header")
}
}
return headerMap, nil
}

// GetConsumer returns the consumer specified by the user on the command line
// or an error if the consumer is invalid. Note that this function explicitly
// calls viper.GetString(OptExtract) internally.
Expand Down
27 changes: 27 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,30 @@ func TestGetCacheSRV(t *testing.T) {
})
}
}

func TestHeadersToMap(t *testing.T) {
testCases := []struct {
name string
headers []string
expected map[string]string
err bool
}{
{"empty", []string{}, nil, false},
{"single", []string{"Authorization: Bearer token123"}, map[string]string{"Authorization": "Bearer token123"}, false},
{"multiple", []string{"Authorization: Bearer token123", "X-Custom-Header: value"}, map[string]string{"Authorization": "Bearer token123", "X-Custom-Header": "value"}, false},
{"with spaces", []string{"Content-Type: application/json"}, map[string]string{"Content-Type": "application/json"}, false},
{"value with colon", []string{"Authorization: Bearer: token:123"}, map[string]string{"Authorization": "Bearer: token:123"}, false},
{"trim spaces", []string{" Authorization : Bearer token123 "}, map[string]string{"Authorization": "Bearer token123"}, false},
{"invalid format no colon", []string{"InvalidHeader"}, nil, true},
{"invalid format empty key", []string{": value"}, nil, true},
{"empty value", []string{"X-Empty-Header:"}, map[string]string{"X-Empty-Header": ""}, false},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
headers, err := HeadersToMap(tc.headers)
assert.Equal(t, tc.err, err != nil)
assert.Equal(t, tc.expected, headers)
})
}
}
1 change: 1 addition & 0 deletions pkg/config/optnames.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
OptExtract = "extract"
OptForce = "force"
OptForceHTTP2 = "force-http2"
OptHeader = "header"
OptLoggingLevel = "log-level"
OptMaxChunks = "max-chunks"
OptMaxConnPerHost = "max-conn-per-host"
Expand Down