Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/aws/escape.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"unicode/utf8"
)

// escapeURIPath writes the path-escaped version of uri to w.
func escapeURIPath(w *bytes.Buffer, uri string) {
var n int
for i, c := range uri {
Expand All @@ -26,13 +27,14 @@ func escapeURIPath(w *bytes.Buffer, uri string) {

func encodeHexUpper(w *bytes.Buffer, s []byte) {
const hexUpper = "0123456789ABCDEF"
for i := 0; i < len(s); i++ {
for i := range s {
b := s[i]
w.WriteByte(hexUpper[b>>4])
w.WriteByte(hexUpper[b&0x0F])
}
}

// mapping of valid uri path bytes.
var validURIBytes = [256]bool{
// -
45: true,
Expand Down
24 changes: 22 additions & 2 deletions internal/aws/sigv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ type Config struct {
Service string
}

// Sign signs the provided HTTP request with the information from Config,
// returning any error encountered.
func Sign(req *http.Request, cfg Config, now time.Time) error {
datetime := now.Format(datetimeFormat)
req.Header.Set("X-Amz-Date", datetime)
Expand All @@ -39,12 +41,14 @@ func Sign(req *http.Request, cfg Config, now time.Time) error {
}
req.Header.Set(headerContentSha256, payload)

// Build the signature.
signedHeaders := getSignedHeaders(req)
canonicalRequest := buildCanonicalRequest(req, signedHeaders, payload)
stringToSign := buildStringToSign(datetime, cfg.Region, cfg.Service, canonicalRequest)
signingKey := createSigningKey(datetime[:8], cfg.Region, cfg.Service, cfg.SecretKey)
signature := hex.EncodeToString(hmacSha256(signingKey, stringToSign))

// Format the Authorization header value.
var sb strings.Builder
sb.Grow(512)

Expand All @@ -70,15 +74,19 @@ func Sign(req *http.Request, cfg Config, now time.Time) error {
return nil
}

// getPayloadHash returns the appropriate payload has for HTTP request and service.
func getPayloadHash(req *http.Request, service string) (string, error) {
// If a payload header already exists, use that.
if payload := req.Header.Get(headerContentSha256); payload != "" {
return payload, nil
}

// Use the empty sha256 if the request has no body.
if req.Body == nil || req.Body == http.NoBody {
return emptySha256, nil
}

// Attempt to utilize the GetBody function if it exists.
if req.GetBody != nil {
body, err := req.GetBody()
if err != nil {
Expand All @@ -88,6 +96,8 @@ func getPayloadHash(req *http.Request, service string) (string, error) {
return hexSha256Reader(body)
}

// If body implements io.ReadSeeker, calculate the hash and seek back
// to the start afterwards.
if rs, ok := req.Body.(io.ReadSeeker); ok {
payload, err := hexSha256Reader(rs)
if err != nil {
Expand All @@ -99,36 +109,46 @@ func getPayloadHash(req *http.Request, service string) (string, error) {
return payload, nil
}

// At this point, if the service is S3, use the "UNISIGNED-PAYLOAD" to
// avoid having to read the entire request body into memory.
if service == "s3" {
return "UNSIGNED-PAYLOAD", nil
}

defer req.Body.Close()
body, err := io.ReadAll(req.Body)
// Read the entire body into memory to calculate the payload hash.
oldBody := req.Body
defer oldBody.Close()
body, err := io.ReadAll(oldBody)
if err != nil {
return "", err
}
req.Body = io.NopCloser(bytes.NewReader(body))
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(body)), nil
}

return hexSha256Reader(bytes.NewReader(body))
}

func getSignedHeaders(req *http.Request) []core.KeyVal {
out := make([]core.KeyVal, 0, len(req.Header)+1)

// Host header is required to be signed.
if _, ok := req.Header["Host"]; !ok {
out = append(out, core.KeyVal{Key: "host", Val: req.URL.Host})
}

for key, vals := range req.Header {
switch key {
case "Authorization", "Content-Length", "User-Agent":
// Avoid signing these headers.
continue
}
key = strings.ToLower(strings.TrimSpace(key))
val := strings.TrimSpace(strings.Join(vals, ","))
out = append(out, core.KeyVal{Key: key, Val: val})
}
// Headers should be ordered by key.
slices.SortFunc(out, func(a, b core.KeyVal) int {
return strings.Compare(a.Key, b.Key)
})
Expand Down
10 changes: 5 additions & 5 deletions internal/cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (
"time"

"github.com/ryanfowler/fetch/internal/aws"
"github.com/ryanfowler/fetch/internal/client"
"github.com/ryanfowler/fetch/internal/core"
"github.com/ryanfowler/fetch/internal/printer"
)

// App represents the full configuration for a fetch invocation.
type App struct {
URL *url.URL

Expand All @@ -33,7 +33,7 @@ type App struct {
Format core.Format
Headers []core.KeyVal
Help bool
HTTP client.HTTPVersion
HTTP core.HTTPVersion
IgnoreStatus bool
Insecure bool
JSON bool
Expand Down Expand Up @@ -394,14 +394,14 @@ func (a *App) CLI() *CLI {
Default: "",
Values: []string{"1", "2"},
IsSet: func() bool {
return a.HTTP != client.HTTPDefault
return a.HTTP != core.HTTPDefault
},
Fn: func(value string) error {
switch value {
case "1":
a.HTTP = client.HTTP1
a.HTTP = core.HTTP1
case "2":
a.HTTP = client.HTTP2
a.HTTP = core.HTTP2
default:
const usage = "must be one of [1, 2]"
return flagValueError("http", value, usage)
Expand Down
Loading
Loading