Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(downloader): resume partial downloads #4537

Merged
merged 10 commits into from
Jan 9, 2025
71 changes: 56 additions & 15 deletions pkg/downloader/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package downloader

import (
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -204,6 +206,25 @@ func removePartialFile(tmpFilePath string) error {
return nil
}

// calculateHashForPartialFile streams the contents of file, starting at its
// current read offset, into a fresh SHA-256 hash and returns that hash so the
// caller can keep feeding it the bytes that are downloaded afterwards.
//
// On success the read offset of file is left at EOF. On failure the returned
// hash is nil and the error wraps the underlying read error.
func calculateHashForPartialFile(file *os.File) (hash.Hash, error) {
	// Named hasher (not hash) so the imported hash package is not shadowed.
	hasher := sha256.New()
	if _, err := io.Copy(hasher, file); err != nil {
		return nil, fmt.Errorf("hashing partial file: %w", err)
	}
	return hasher, nil
}

// checkSeverSupportsRangeHeader sends a HEAD request to the resolved URL and
// reports whether the response carries an "Accept-Ranges" header whose value
// is exactly "bytes". Any error from the HEAD request is returned unchanged
// together with false.
func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
	headResp, err := http.Head(uri.ResolveURL())
	if err != nil {
		return false, err
	}
	defer headResp.Body.Close()

	// Only the literal "bytes" range unit counts as support.
	acceptRanges := headResp.Header.Get("Accept-Ranges")
	supported := acceptRanges == "bytes"
	return supported, nil
}

func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
url := uri.ResolveURL()
if uri.LooksLikeOCI() {
Expand Down Expand Up @@ -266,8 +287,34 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat

log.Info().Msgf("Downloading %q", url)

// Download file
resp, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request for %q: %v", filePath, err)
}

// save partial download to dedicated file
tmpFilePath := filePath + ".partial"
tmpFileInfo, err := os.Stat(tmpFilePath)
if err == nil {
support, err := uri.checkSeverSupportsRangeHeader()
if err != nil {
return fmt.Errorf("failed to check if uri server supports range header: %v", err)
}
if support {
startPos := tmpFileInfo.Size()
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", startPos))
} else {
err := removePartialFile(tmpFilePath)
if err != nil {
return err
}
}
} else if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("failed to check file %q existence: %v", filePath, err)
}

// Start the request
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download file %q: %v", filePath, err)
}
Expand All @@ -283,26 +330,20 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err)
}

// save partial download to dedicated file
tmpFilePath := filePath + ".partial"

// remove tmp file
err = removePartialFile(tmpFilePath)
// Create and write file
outFile, err := os.OpenFile(tmpFilePath, os.O_APPEND|os.O_RDWR|os.O_CREATE, 0644)
if err != nil {
return err
return fmt.Errorf("failed to create / open file %q: %v", tmpFilePath, err)
}

// Create and write file content
outFile, err := os.Create(tmpFilePath)
defer outFile.Close()
hash, err := calculateHashForPartialFile(outFile)
if err != nil {
return fmt.Errorf("failed to create file %q: %v", tmpFilePath, err)
return fmt.Errorf("failed to calculate hash for partial file")
}
defer outFile.Close()

progress := &progressWriter{
fileName: tmpFilePath,
total: resp.ContentLength,
hash: sha256.New(),
hash: hash,
fileNo: fileN,
totalFiles: total,
downloadStatus: downloadStatus,
Expand Down
145 changes: 145 additions & 0 deletions pkg/downloader/uri_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
package downloader_test

import (
"crypto/rand"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"os"
"regexp"
"strconv"

. "github.com/mudler/LocalAI/pkg/downloader"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -38,3 +47,139 @@ var _ = Describe("Gallery API tests", func() {
})
})
})

// RangeHeaderError is the error type used to signal a Range header value that
// could not be understood.
type RangeHeaderError struct {
	msg string // text returned verbatim by Error
}

// Error implements the error interface by returning the stored message.
func (r *RangeHeaderError) Error() string {
	return r.msg
}

// Download Test exercises URI.DownloadFile against a local httptest server
// that can optionally advertise "Accept-Ranges: bytes" on HEAD requests.
var _ = Describe("Download Test", func() {
	var mockData []byte
	var mockDataSha string
	var filePath string

	// Compiled once instead of on every request. Accepts "bytes=<start>-" and
	// "bytes=<start>-<end>".
	rangeHeaderPattern := regexp.MustCompile(`^bytes=(\d+)-(\d+|)$`)

	// extractRangeHeader parses a Range header value and returns the start
	// offset and the exclusive end offset. The end offset is -1 when the
	// header leaves it open. A value that does not match the expected shape
	// yields a *RangeHeaderError; a number that cannot be converted yields
	// the strconv error.
	extractRangeHeader := func(rangeString string) (int, int, error) {
		matches := rangeHeaderPattern.FindStringSubmatch(rangeString)
		if matches == nil {
			return -1, -1, &RangeHeaderError{msg: "invalid / ill-formatted range"}
		}
		startPos, err := strconv.Atoi(matches[1])
		if err != nil {
			return -1, -1, err
		}

		endPos := -1
		if matches[2] != "" {
			endPos, err = strconv.Atoi(matches[2])
			if err != nil {
				return -1, -1, err
			}
			endPos++ // the header's end is inclusive, slices are exclusive
		}
		return startPos, endPos, nil
	}

	// getMockServer starts a server that serves mockData on GET (honoring a
	// Range header when present) and answers HEAD with "Accept-Ranges: bytes"
	// only when supportsRangeHeader is true. The caller must Close it.
	getMockServer := func(supportsRangeHeader bool) *httptest.Server {
		mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// The handler runs on a server goroutine; a failing Expect there
			// must be handed back to Ginkgo.
			defer GinkgoRecover()

			switch r.Method {
			case http.MethodHead:
				if supportsRangeHeader {
					w.Header().Add("Accept-Ranges", "bytes")
				}
				w.WriteHeader(http.StatusOK)
				return
			case http.MethodGet:
				// served below
			default:
				w.WriteHeader(http.StatusNotFound)
				return
			}

			startPos, endPos := 0, len(mockData)
			if rangeString := r.Header.Get("Range"); rangeString != "" {
				var err error
				startPos, endPos, err = extractRangeHeader(rangeString)
				if err != nil {
					if _, ok := err.(*RangeHeaderError); ok {
						w.WriteHeader(http.StatusBadRequest)
						return
					}
					Expect(err).ToNot(HaveOccurred())
				}
				if endPos == -1 {
					endPos = len(mockData)
				}
				if startPos < 0 || startPos >= len(mockData) || endPos < 0 || endPos > len(mockData) || startPos > endPos {
					w.WriteHeader(http.StatusBadRequest)
					return
				}
			}
			w.WriteHeader(http.StatusOK)
			// A failed write only means the client went away; the spec's own
			// assertions will surface that.
			_, _ = w.Write(mockData[startPos:endPos])
		}))
		mockServer.EnableHTTP2 = true
		mockServer.Start()
		return mockServer
	}

	// writePartialFile creates (or truncates) the ".partial" file next to
	// filePath, fills it with data and closes it, so no handle is still open
	// when DownloadFile works on the same path.
	writePartialFile := func(data []byte) {
		file, err := os.OpenFile(filePath+".partial", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
		Expect(err).ToNot(HaveOccurred())
		_, writeErr := file.Write(data)
		closeErr := file.Close()
		Expect(writeErr).ToNot(HaveOccurred())
		Expect(closeErr).ToNot(HaveOccurred())
	}

	BeforeEach(func() {
		mockData = make([]byte, 20000)
		_, err := rand.Read(mockData)
		Expect(err).ToNot(HaveOccurred())
		_mockDataSha := sha256.New()
		_, err = _mockDataSha.Write(mockData)
		Expect(err).ToNot(HaveOccurred())
		mockDataSha = fmt.Sprintf("%x", _mockDataSha.Sum(nil))
		dir, err := os.Getwd()
		// Check the error before dir is used to build the target path.
		Expect(err).NotTo(HaveOccurred())
		filePath = dir + "/my_supercool_model"
	})

	Context("URI DownloadFile", func() {
		It("fetches files from mock server", func() {
			mockServer := getMockServer(true)
			defer mockServer.Close()
			uri := URI(mockServer.URL)
			err := uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {})
			Expect(err).ToNot(HaveOccurred())
		})

		It("resumes partially downloaded files", func() {
			mockServer := getMockServer(true)
			defer mockServer.Close()
			uri := URI(mockServer.URL)
			// Create a partial file holding the first half of the payload.
			writePartialFile(mockData[0:10000])
			err := uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {})
			Expect(err).ToNot(HaveOccurred())
		})

		It("restarts download from 0 if server doesn't support Range header", func() {
			mockServer := getMockServer(false)
			defer mockServer.Close()
			uri := URI(mockServer.URL)
			// Create a partial file holding the first half of the payload.
			writePartialFile(mockData[0:10000])
			err := uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {})
			Expect(err).ToNot(HaveOccurred())
		})
	})

	AfterEach(func() {
		// Best-effort cleanup: either file may legitimately be absent, so the
		// errors are deliberately ignored.
		_ = os.Remove(filePath)
		_ = os.Remove(filePath + ".partial")
	})
})
Loading