Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cmd/crane/cmd/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
var (
cachePath, format string
annotateRef bool
resumable bool
)

cmd := &cobra.Command{
Expand All @@ -49,6 +50,10 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
return fmt.Errorf("parsing reference %q: %w", src, err)
}

if resumable {
o.Remote = append(o.Remote, remote.WithResumable())
}

rmt, err := remote.Get(ref, o.Remote...)
if err != nil {
return err
Expand Down Expand Up @@ -133,6 +138,7 @@ func NewCmdPull(options *[]crane.Option) *cobra.Command {
cmd.Flags().StringVarP(&cachePath, "cache_path", "c", "", "Path to cache image layers")
cmd.Flags().StringVar(&format, "format", "tarball", fmt.Sprintf("Format in which to save images (%q, %q, or %q)", "tarball", "legacy", "oci"))
cmd.Flags().BoolVar(&annotateRef, "annotate-ref", false, "Preserves image reference used to pull as an annotation when used with --format=oci")
cmd.Flags().BoolVar(&resumable, "resumable", false, "Enable resumable transport for pulling images")

return cmd
}
44 changes: 44 additions & 0 deletions pkg/v1/remote/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package remote
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -747,3 +749,45 @@ func TestData(t *testing.T) {
t.Fatal(err)
}
}

// TestImageResumable fetches an image with the resumable option enabled and
// checks that the SHA-256 of every compressed layer stream equals the digest
// the layer reports.
//
// NOTE(review): the reference points at ghcr.io and no local test server is
// set up in this function, so the test appears to need network access —
// confirm that this is acceptable for the suite.
func TestImageResumable(t *testing.T) {
	ref, err := name.ParseReference("ghcr.io/labring/fastgpt:v4.9.0")
	if err != nil {
		t.Fatal(err)
	}

	img, err := Image(ref, WithResumable())
	if err != nil {
		t.Fatal(err)
	}

	layers, err := img.Layers()
	if err != nil {
		t.Fatal(err)
	}

	for i := range layers {
		want, err := layers[i].Digest()
		if err != nil {
			t.Fatal(err)
		}

		blob, err := layers[i].Compressed()
		if err != nil {
			t.Fatal(err)
		}

		// Stream the whole layer through the hasher; the stream is closed
		// before the copy error is inspected so it is released on failure too.
		hasher := sha256.New()
		_, copyErr := io.Copy(hasher, blob)
		blob.Close()
		if copyErr != nil {
			t.Fatal(copyErr)
		}

		got := hex.EncodeToString(hasher.Sum(nil))
		if want.Hex != got {
			t.Errorf("digest mismatch: %s != %s", want, got)
			continue
		}
		t.Logf("digest matches: %s", want)
	}
}
13 changes: 13 additions & 0 deletions pkg/v1/remote/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type options struct {
retryBackoff Backoff
retryPredicate retry.Predicate
retryStatusCodes []int
resumable bool

// Only these options can overwrite Reuse()d options.
platform v1.Platform
Expand Down Expand Up @@ -170,6 +171,11 @@ func makeOptions(opts ...Option) (*options, error) {

// Wrap the transport in something that can retry network flakes.
o.transport = transport.NewRetry(o.transport, transport.WithRetryBackoff(o.retryBackoff), transport.WithRetryPredicate(predicate), transport.WithRetryStatusCodes(o.retryStatusCodes...))

if o.resumable {
o.transport = transport.NewResumable(o.transport)
}

// Wrap this last to prevent transport.New from double-wrapping.
if o.userAgent != "" {
o.transport = transport.NewUserAgent(o.transport, o.userAgent)
Expand All @@ -192,6 +198,13 @@ func WithTransport(t http.RoundTripper) Option {
}
}

// WithResumable is a functional option that marks the options as resumable.
// When the flag is set, makeOptions wraps the transport with
// transport.NewResumable.
func WithResumable() Option {
	enable := func(o *options) error {
		o.resumable = true
		return nil
	}
	return enable
}

// WithAuth is a functional option for overriding the default authenticator
// for remote operations.
// It is an error to use both WithAuth and WithAuthFromKeychain in the same Option set.
Expand Down
242 changes: 242 additions & 0 deletions pkg/v1/remote/transport/resumable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
package transport

import (
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"sync/atomic"

"github.com/google/go-containerregistry/pkg/logs"
)

// NewResumable returns an http.RoundTripper that resumes an interrupted HTTP
// GET response body by issuing follow-up range requests through inner.
//
// inner should itself be wrapped with the retry transport; otherwise a read
// is aborted as soon as a single resume attempt returns an error.
func NewResumable(inner http.RoundTripper) http.RoundTripper {
	rt := new(resumableTransport)
	rt.inner = inner
	return rt
}

var (
	// contentRangeRe matches a Content-Range header value of the form
	// "bytes <start>-<end>/<size>", where <size> is either a decimal number
	// or "*". The three capture groups are start, end and size, in that order.
	contentRangeRe = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)$`)
)

// resumableTransport is the http.RoundTripper returned by NewResumable. It
// turns GET requests into range requests so that an interrupted response body
// can be continued from the last byte received.
type resumableTransport struct {
	// inner performs the actual HTTP exchanges, both for the initial request
	// and for every follow-up range request.
	inner http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
//
// Requests that are not GET, or that already carry a Range header, are
// forwarded to the inner transport unchanged. Any other GET is re-issued as a
// clone with "Range: bytes=0-":
//
//   - a 206 response whose Content-Range starts at offset 0 and declares a
//     positive total size is presented to the caller as "200 OK" with
//     ContentLength set to that size and a body that resumes itself after
//     read errors;
//   - a 416 response, or a 206 whose Content-Range is missing, malformed,
//     does not start at 0, or has an unknown ("*") or zero size, is closed
//     and the original request is sent again without the Range header;
//   - every other response is returned as received.
//
// The caller's request is never modified.
func (rt *resumableTransport) RoundTrip(in *http.Request) (*http.Response, error) {
	// A caller-supplied Range must not be overwritten with "bytes=0-": doing
	// so would return bytes the caller did not ask for.
	if in.Method != http.MethodGet || in.Header.Get("Range") != "" {
		return rt.inner.RoundTrip(in)
	}

	req := in.Clone(in.Context())
	req.Header.Set("Range", "bytes=0-")
	resp, err := rt.inner.RoundTrip(req)
	if err != nil {
		return resp, err
	}

	switch resp.StatusCode {
	case http.StatusPartialContent:
		// The server honoured the range request; checked further below.
	case http.StatusRequestedRangeNotSatisfiable:
		// Fall back to the plain, non-ranged request.
		resp.Body.Close()
		return rt.inner.RoundTrip(in)
	default:
		return resp, nil
	}

	start, _, total, err := parseContentRange(resp.Header.Get("Content-Range"))
	if err != nil || start != 0 || total <= 0 {
		// Without a usable Content-Range starting at byte 0 the body cannot
		// be tracked reliably, so fall back to the plain request.
		resp.Body.Close()
		return rt.inner.RoundTrip(in)
	}

	// Report 200 so that callers checking for http.StatusOK keep working.
	resp.StatusCode = http.StatusOK
	resp.Status = "200 OK"
	resp.ContentLength = total
	resp.Body = &resumableBody{
		rc:          resp.Body,
		inner:       rt.inner,
		req:         req,
		total:       total,
		transferred: 0,
	}

	return resp, nil
}

// resumableBody is the response body installed by resumableTransport. It
// reads from the current underlying body and, when a read fails before the
// expected number of bytes has arrived, requests the remainder through the
// inner transport and continues from there.
type resumableBody struct {
	// rc is the body of the response currently being read; resume replaces
	// it with the body of the follow-up response.
	rc io.ReadCloser

	// inner is the transport used to issue follow-up range requests.
	inner http.RoundTripper
	// req is the ranged clone of the caller's request; resume clones it
	// again and rewrites only the Range header.
	req *http.Request

	// transferred is the number of bytes returned from Read so far.
	transferred int64
	// total is the complete size taken from the Content-Range header.
	total int64

	// closed is set to 1 by Close and is accessed through sync/atomic.
	closed uint32
}

func (rb *resumableBody) Read(p []byte) (n int, err error) {
if atomic.LoadUint32(&rb.closed) == 1 {
// response body already closed
return 0, http.ErrBodyReadAfterClose
} else if rb.total >= 0 && rb.transferred >= rb.total {
return 0, io.EOF
}

resume:
if n, err = rb.rc.Read(p); n > 0 {
rb.transferred += int64(n)
}

if err == nil {
return
}

if errors.Is(err, io.EOF) && rb.total >= 0 && rb.transferred == rb.total {
return
}

if err = rb.resume(err); err == nil {
if n == 0 {
// zero bytes read, try reading again with new response.Body
goto resume
}

// already read some bytes from previous response.Body, returns and waits for next Read operation
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please avoid labels and express this as a loop.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, I'm not familiar with Github's systems. I'll raise this in a review.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(reopening)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please avoid labels and express this as a loop.

modified


return n, err
}

// Close implements io.Closer. The first call marks the body as closed and
// closes the current underlying body, returning its error; every later call
// does nothing and returns nil.
func (rb *resumableBody) Close() (err error) {
	if atomic.CompareAndSwapUint32(&rb.closed, 0, 1) {
		return rb.rc.Close()
	}

	return nil
}

// resume swaps the underlying body for one that continues at the current
// offset. reason is the read error that triggered the resume and is only
// logged.
//
// It returns the context error if the request context is already done, the
// transport error if the follow-up request fails, the validation error if the
// new response is not usable, and http.ErrBodyReadAfterClose if the body was
// closed in the meantime. In the last two cases the new response body is
// closed. On success the previous body is closed and replaced.
func (rb *resumableBody) resume(reason error) error {
	if reason != nil {
		logs.Debug.Printf("Resume http transporting from error: %v", reason)
	}

	// Do not start another request once the caller has given up.
	ctx := rb.req.Context()
	if ctxErr := ctx.Err(); ctxErr != nil {
		return ctxErr
	}

	next := rb.req.Clone(ctx)
	next.Header.Set("Range", "bytes="+strconv.FormatInt(rb.transferred, 10)+"-")
	resp, err := rb.inner.RoundTrip(next)
	if err != nil {
		return err
	}

	switch verr := rb.validate(resp); {
	case verr != nil:
		resp.Body.Close()
		return verr
	case atomic.LoadUint32(&rb.closed) == 1:
		resp.Body.Close()
		return http.ErrBodyReadAfterClose
	}

	rb.rc.Close()
	rb.rc = resp.Body

	return nil
}

// validate checks that resp can continue the transfer at rb.transferred and,
// where needed, positions resp.Body at that offset.
//
//   - 206: the Content-Range header must parse. A larger declared total
//     replaces rb.total. A start beyond rb.transferred is an error; a start
//     before it causes the overlapping bytes to be discarded.
//   - 200: the server sent the whole resource again, so the first
//     rb.transferred bytes are discarded.
//   - 416: if the header is "bytes */<total>" and rb.transferred has reached
//     that total, io.EOF is returned to signal that nothing is left to read.
//
// Anything else yields an "unexpected status code" error. If the body ends
// while bytes are still being discarded, the error wraps io.ErrUnexpectedEOF
// rather than io.EOF, so a truncated response is never reported to the reader
// as a clean end of stream.
func (rb *resumableBody) validate(resp *http.Response) (err error) {
	// skip drains n bytes from the new body; io.CopyN reports a short body as
	// io.EOF, which is turned into io.ErrUnexpectedEOF here.
	skip := func(n int64) error {
		_, copyErr := io.CopyN(io.Discard, resp.Body, n)
		if errors.Is(copyErr, io.EOF) {
			return io.ErrUnexpectedEOF
		}
		return copyErr
	}

	switch resp.StatusCode {
	case http.StatusPartialContent:
		var start, total int64
		if start, _, total, err = parseContentRange(resp.Header.Get("Content-Range")); err != nil {
			return err
		}

		if total > rb.total {
			rb.total = total
		}

		switch {
		case start > rb.transferred:
			return fmt.Errorf("unexpected resume start %d, wanted: %d", start, rb.transferred)
		case start < rb.transferred:
			if err = skip(rb.transferred - start); err != nil {
				return fmt.Errorf("discard overlapped data failed, %w", err)
			}
		}
		return nil

	case http.StatusOK:
		if rb.transferred > 0 {
			if err = skip(rb.transferred); err != nil {
				return fmt.Errorf("discard already transferred data failed, %w", err)
			}
		}
		return nil

	case http.StatusRequestedRangeNotSatisfiable:
		if contentRange := resp.Header.Get("Content-Range"); strings.HasPrefix(contentRange, "bytes */") {
			total, parseErr := strconv.ParseInt(strings.TrimPrefix(contentRange, "bytes */"), 10, 64)
			if parseErr == nil && total >= 0 && rb.transferred >= total {
				return io.EOF
			}
		}
	}

	return fmt.Errorf("unexpected status code %d", resp.StatusCode)
}

// parseContentRange parses a Content-Range header value of the form
// "bytes <start>-<end>/<size>" or "bytes <start>-<end>/*".
//
// It returns the inclusive start and end offsets and the complete size; size
// is -1 when the header carries "*". On failure all three numbers are -1 and
// err is non-nil. Failures are: an empty value, a value that does not match
// the expected form, start > end, end >= size, or a number that does not fit
// in an int64. Number conversion failures wrap the underlying strconv error,
// so callers can inspect them with errors.Is and errors.As.
func parseContentRange(contentRange string) (start, end, size int64, err error) {
	if contentRange == "" {
		return -1, -1, -1, errors.New("unexpected empty content range")
	}

	matches := contentRangeRe.FindStringSubmatch(contentRange)
	if len(matches) != 4 {
		return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange)
	}

	if start, err = strconv.ParseInt(matches[1], 10, 64); err != nil {
		return -1, -1, -1, fmt.Errorf("unexpected start from content range '%s', %w", contentRange, err)
	}

	if end, err = strconv.ParseInt(matches[2], 10, 64); err != nil {
		return -1, -1, -1, fmt.Errorf("unexpected end from content range '%s', %w", contentRange, err)
	}

	if start > end {
		return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange)
	}

	// "*" means the server does not know (or does not disclose) the size.
	if matches[3] == "*" {
		return start, end, -1, nil
	}

	if size, err = strconv.ParseInt(matches[3], 10, 64); err != nil {
		return -1, -1, -1, fmt.Errorf("unexpected total from content range '%s', %w", contentRange, err)
	}

	// end is an inclusive offset, so it must lie strictly below the size.
	if end >= size {
		return -1, -1, -1, fmt.Errorf("invalid content range: %s", contentRange)
	}

	return start, end, size, nil
}
Loading