diff --git a/pkg/bspatch/bspatch.go b/pkg/bspatch/bspatch.go index 3d32898..a3f5dc9 100644 --- a/pkg/bspatch/bspatch.go +++ b/pkg/bspatch/bspatch.go @@ -27,9 +27,10 @@ package bspatch import ( "bytes" + "encoding/binary" "fmt" "io" - "io/ioutil" + "os" "github.com/dsnet/compress/bzip2" "github.com/gabstv/go-bsdiff/pkg/util" @@ -42,11 +43,11 @@ func Bytes(oldfile, patch []byte) (newfile []byte, err error) { // Reader applies a BSDIFF4 patch (using oldbin and patchf) to create the newbin func Reader(oldbin io.Reader, newbin io.Writer, patchf io.Reader) error { - oldbs, err := ioutil.ReadAll(oldbin) + oldbs, err := io.ReadAll(oldbin) if err != nil { return err } - diffbytes, err := ioutil.ReadAll(patchf) + diffbytes, err := io.ReadAll(patchf) if err != nil { return err } @@ -59,11 +60,11 @@ func Reader(oldbin io.Reader, newbin io.Writer, patchf io.Reader) error { // File applies a BSDIFF4 patch (using oldfile and patchfile) to create the newfile func File(oldfile, newfile, patchfile string) error { - oldbs, err := ioutil.ReadFile(oldfile) + oldbs, err := os.ReadFile(oldfile) if err != nil { return fmt.Errorf("could not read oldfile '%v': %v", oldfile, err.Error()) } - patchbs, err := ioutil.ReadFile(patchfile) + patchbs, err := os.ReadFile(patchfile) if err != nil { return fmt.Errorf("could not read patchfile '%v': %v", patchfile, err.Error()) } @@ -71,7 +72,7 @@ func File(oldfile, newfile, patchfile string) error { if err != nil { return fmt.Errorf("bspatch: %v", err.Error()) } - if err := ioutil.WriteFile(newfile, newbytes, 0644); err != nil { + if err := os.WriteFile(newfile, newbytes, 0644); err != nil { return fmt.Errorf("could not create newfile '%v': %v", newfile, err.Error()) } return nil @@ -82,7 +83,6 @@ func patchb(oldfile, patch []byte) ([]byte, error) { var newsize int header := make([]byte, 32) buf := make([]byte, 8) - var lenread int var i int ctrl := make([]int, 3) @@ -131,6 +131,7 @@ func patchb(oldfile, patch []byte) ([]byte, error) { if err != nil { return nil, err } + defer cpfbz2.Close() dpf := bytes.NewReader(patch) if _, err := dpf.Seek(int64(32+bzctrllen), io.SeekStart); err != nil { return nil, err @@ -139,6 +140,7 @@ func patchb(oldfile, patch []byte) ([]byte, error) { if err != nil { return nil, err } + defer dpfbz2.Close() epf := bytes.NewReader(patch) if _, err := epf.Seek(int64(32+bzctrllen+bzdatalen), io.SeekStart); err != nil { return nil, err @@ -147,6 +149,7 @@ func patchb(oldfile, patch []byte) ([]byte, error) { if err != nil { return nil, err } + defer epfbz2.Close() pnew := make([]byte, newsize) @@ -156,13 +159,8 @@ func patchb(oldfile, patch []byte) ([]byte, error) { for newpos < newsize { // Read control data for i = 0; i <= 2; i++ { - lenread, err = zreadall(cpfbz2, buf, 8) - if lenread != 8 || (err != nil && err != io.EOF) { - e0 := "" - if err != nil { - e0 = err.Error() - } - return nil, fmt.Errorf("corrupt patch or bzstream ended: %s (read: %v/8)", e0, lenread) + if _, err = io.ReadFull(cpfbz2, buf); err != nil { + return nil, fmt.Errorf("corrupt patch or bzstream ended: %s", err) } ctrl[i] = offtin(buf) } @@ -172,14 +170,8 @@ func patchb(oldfile, patch []byte) ([]byte, error) { } // Read diff string - // lenread, err = dpfbz2.Read(pnew[newpos : newpos+ctrl[0]]) - lenread, err = zreadall(dpfbz2, pnew[newpos:newpos+ctrl[0]], ctrl[0]) - if lenread < ctrl[0] || (err != nil && err != io.EOF) { - e0 := "" - if err != nil { - e0 = err.Error() - } - return nil, fmt.Errorf("corrupt patch or bzstream ended (2): %s", e0) + if _, err := io.ReadFull(dpfbz2, pnew[newpos:newpos+ctrl[0]]); err != nil { + return nil, fmt.Errorf("corrupt patch or bzstream ended (2): %s", err) } // Add pold data to diff string for i = 0; i < ctrl[0]; i++ { @@ -198,81 +190,22 @@ func patchb(oldfile, patch []byte) ([]byte, error) { } // Read extra string - // epfbz2.Read was not reading all the requested bytes, probably an internal buffer limitation ? - // it was encapsulated by zreadall to work around the issue - lenread, err = zreadall(epfbz2, pnew[newpos:newpos+ctrl[1]], ctrl[1]) - if lenread < ctrl[1] || (err != nil && err != io.EOF) { - e0 := "" - if err != nil { - e0 = err.Error() - } - return nil, fmt.Errorf("corrupt patch or bzstream ended (3): %s", e0) + if _, err := io.ReadFull(epfbz2, pnew[newpos:newpos+ctrl[1]]); err != nil { + return nil, fmt.Errorf("corrupt patch or bzstream ended (3): %s", err) } // Adjust pointers newpos += ctrl[1] oldpos += ctrl[2] } - // Clean up the bzip2 reads - if err = cpfbz2.Close(); err != nil { - return nil, err - } - if err = dpfbz2.Close(); err != nil { - return nil, err - } - if err = epfbz2.Close(); err != nil { - return nil, err - } - cpfbz2 = nil - dpfbz2 = nil - epfbz2 = nil - cpf = nil - dpf = nil - epf = nil - return pnew, nil } -// offtin reads an int64 (little endian) +// offtin reads an int64 (little endian using a sign-bit, not two's complement) func offtin(buf []byte) int { - - y := int(buf[7] & 0x7f) - y = y * 256 - y += int(buf[6]) - y = y * 256 - y += int(buf[5]) - y = y * 256 - y += int(buf[4]) - y = y * 256 - y += int(buf[3]) - y = y * 256 - y += int(buf[2]) - y = y * 256 - y += int(buf[1]) - y = y * 256 - y += int(buf[0]) - - if (buf[7] & 0x80) != 0 { - y = -y - } - return y -} - -func zreadall(r io.Reader, b []byte, expected int) (int, error) { - var allread int - var offset int - for { - nread, err := r.Read(b[offset:]) - if nread == expected { - return nread, err - } - if err != nil { - return allread + nread, err - } - allread += nread - if allread >= expected { - return allread, nil - } - offset += nread + v := binary.LittleEndian.Uint64(buf) + if v&(1<<63) != 0 { + return -int(v &^ (1 << 63)) } + return int(v) } diff --git a/pkg/bspatch/bspatch_test.go b/pkg/bspatch/bspatch_test.go index 4e9c7a1..bdca901 100644 --- a/pkg/bspatch/bspatch_test.go +++ b/pkg/bspatch/bspatch_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "io" "io/ioutil" "os" "testing" @@ -280,7 +281,7 @@ func TestZReadAll(t *testing.T) { rr := &lowcaprdr{ read: make([]byte, 1024), } - nr, err := zreadall(rr, buf, len(buf)) + nr, err := io.ReadFull(rr, buf) if err != nil { t.Fail() }