Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fast generator for cog build #2108

Merged
merged 21 commits into from
Jan 15, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add copy weights step
8W9aG committed Jan 10, 2025
commit 39999486f8f6a116b8d3625520c82761d89ffc29
73 changes: 59 additions & 14 deletions pkg/dockerfile/fast_generator.go
Original file line number Diff line number Diff line change
@@ -10,6 +10,8 @@ import (
"github.com/replicate/cog/pkg/weights"
)

// FUSE_RPC_WEIGHTS_PATH is the directory inside the image that copyWeights
// creates (mkdir -p) and copies detected weight files into. Each copy is
// named after the hex SHA-256 digest that FindWeights computed for the file.
const FUSE_RPC_WEIGHTS_PATH = "/srv/r8/fuse-rpc/weights"

type FastGenerator struct {
Config *config.Config
Dir string
@@ -74,11 +76,18 @@ func (g *FastGenerator) generate() (string, error) {
if err != nil {
return "", err
}

lines := []string{}
lines, err = g.generateMonobase(lines, tmpDir)
if err != nil {
return "", err
}

lines, err = g.copyWeights(lines)
if err != nil {
return "", err
}

return strings.Join(lines, "\n"), nil
}

@@ -106,36 +115,72 @@ func (g *FastGenerator) copyCog(tmpDir string) (string, error) {
}

// generateMonobase appends the monobase stage of the fast Dockerfile to lines
// and returns the extended slice.
//
// The emitted instructions are, in order:
//   - FROM monobase:latest
//   - ENV R8_COG_VERSION pointing at the cog artifact that copyCog placed in
//     tmpDir (exposed to the build at /buildtmp)
//   - the CUDA / cuDNN ENV lines, only when g.Config.Build.GPU is set
//   - ENV R8_PYTHON_VERSION
//   - ENV R8_TORCH_VERSION, only when g.Config.TorchVersion reports a version
//   - the RUN line that bind-mounts tmpDir (relative to g.Dir) and invokes
//     monobase's build.sh
//
// It returns a nil slice and the error if copyCog fails or tmpDir cannot be
// expressed relative to g.Dir.
func (g *FastGenerator) generateMonobase(lines []string, tmpDir string) ([]string, error) {
	lines = append(lines, "FROM monobase:latest")

	cogPath, err := g.copyCog(tmpDir)
	if err != nil {
		return nil, err
	}
	lines = append(lines, "ENV R8_COG_VERSION=\"file:///buildtmp/"+filepath.Base(cogPath)+"\"")

	// The bind mount source must be relative to the build context (g.Dir).
	relativeTmpDir, err := filepath.Rel(g.Dir, tmpDir)
	if err != nil {
		return nil, err
	}

	// CPU-only builds skip CUDA entirely; GPU builds pin the configured
	// CUDA/cuDNN versions and tell monobase where to download them from.
	skipCudaArg := "--skip-cuda"
	if g.Config.Build.GPU {
		skipCudaArg = ""
		lines = append(lines,
			"ENV R8_CUDA_VERSION="+g.Config.Build.CUDA,
			"ENV R8_CUDNN_VERSION="+g.Config.Build.CuDNN,
			"ENV R8_CUDA_PREFIX=https://monobase.replicate.delivery/cuda",
			"ENV R8_CUDNN_PREFIX=https://monobase.replicate.delivery/cudnn",
		)
	}

	lines = append(lines, "ENV R8_PYTHON_VERSION="+g.Config.Build.PythonVersion)

	// Only advertise a torch version when the config actually declares one.
	if torchVersion, ok := g.Config.TorchVersion(); ok {
		lines = append(lines, "ENV R8_TORCH_VERSION="+torchVersion)
	}

	return append(lines,
		"RUN --mount=type=bind,source=\""+relativeTmpDir+"\",target=/buildtmp /opt/r8/monobase/build.sh "+skipCudaArg+" --mini",
	), nil
}

// copyWeights appends a single RUN instruction that copies every weight file
// FindWeights discovers under g.Dir into FUSE_RPC_WEIGHTS_PATH, naming each
// copy after the file's hex SHA-256 digest. The build context is bind-mounted
// read-only at /src for the duration of that instruction.
//
// When no weights are found, lines is returned unchanged. A nil slice and the
// error are returned if discovery fails or a weight path cannot be made
// relative to g.Dir.
func (g *FastGenerator) copyWeights(lines []string) ([]string, error) {
	// Named weightFiles (not weights) so the imported weights package is not
	// shadowed inside this method.
	weightFiles, err := FindWeights(g.Dir)
	if err != nil {
		return nil, err
	}
	if len(weightFiles) == 0 {
		return lines, nil
	}

	// Map iteration order is randomized, which would reorder the cp commands
	// between runs and needlessly invalidate the Docker layer cache. Collect
	// the digests in ascending order (insertion sort; the set is small) so the
	// generated instruction is stable.
	digests := make([]string, 0, len(weightFiles))
	for digest := range weightFiles {
		i := len(digests)
		digests = append(digests, digest)
		for ; i > 0 && digests[i-1] > digest; i-- {
			digests[i] = digests[i-1]
		}
		digests[i] = digest
	}

	// NOTE(review): paths are interpolated into a shell command unquoted, so a
	// weight path containing whitespace or shell metacharacters will break the
	// cp invocation — confirm whether such paths need to be supported.
	commands := make([]string, 0, len(digests))
	for _, digest := range digests {
		relPath, err := filepath.Rel(g.Dir, weightFiles[digest])
		if err != nil {
			return nil, err
		}
		commands = append(commands, "cp /src/"+relPath+" "+filepath.Join(FUSE_RPC_WEIGHTS_PATH, digest))
	}

	return append(lines,
		"RUN --mount=type=bind,ro,source=.,target=/src mkdir -p "+FUSE_RPC_WEIGHTS_PATH+" && "+strings.Join(commands, " && "),
	), nil
}
6 changes: 5 additions & 1 deletion pkg/dockerfile/user_cache.go
Original file line number Diff line number Diff line change
@@ -2,18 +2,22 @@ package dockerfile

import (
"os"
"os/user"
"path"
"path/filepath"
)

// UserCache returns the current user's cog cache directory,
// <home>/.cog/cache, creating it and any missing parents with mode 0o755.
//
// The home directory is taken from user.Current rather than a "~" prefix:
// path/filepath does not expand "~", so filepath.Abs("~/.cog/cache") would
// resolve to a literal "~" directory under the working directory.
//
// An error is returned if the current user cannot be determined or the
// directory cannot be created.
func UserCache() (string, error) {
	usr, err := user.Current()
	if err != nil {
		return "", err
	}

	// Named cacheDir rather than path so the imported path package is not
	// shadowed.
	cacheDir := filepath.Join(usr.HomeDir, ".cog", "cache")
	if err := os.MkdirAll(cacheDir, 0o755); err != nil {
		return "", err
	}

	return cacheDir, nil
}

67 changes: 67 additions & 0 deletions pkg/dockerfile/weights.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package dockerfile

import (
"crypto/sha256"
"encoding/hex"
"io"
"os"
"path/filepath"
"slices"
)

// File-extension lists consulted by FindWeights. Extensions are compared
// exactly as filepath.Ext returns them (leading dot, case-sensitive).
var (
	// WEIGHT_FILE_EXCLUSIONS lists extensions that are never treated as
	// weights, regardless of file size.
	WEIGHT_FILE_EXCLUSIONS = []string{
		".gif", ".ipynb", ".jpeg", ".jpg", ".log",
		".mp4", ".png", ".svg", ".webp",
	}

	// WEIGHT_FILE_INCLUSIONS lists extensions that mark a file as a weight
	// once it is larger than WEIGHT_FILE_SIZE_EXCLUSION.
	WEIGHT_FILE_INCLUSIONS = []string{
		".ckpt", ".h5", ".onnx", ".pb", ".pbtxt",
		".pt", ".pth", ".safetensors", ".tflite",
	}
)

// Size thresholds, in bytes, consulted by FindWeights.
const (
	// WEIGHT_FILE_SIZE_EXCLUSION (1 MiB): files of this size or smaller are
	// never treated as weights.
	WEIGHT_FILE_SIZE_EXCLUSION = 1 << 20

	// WEIGHT_FILE_SIZE_INCLUSION (128 MiB): files of this size or larger are
	// treated as weights even without a known weight extension.
	WEIGHT_FILE_SIZE_INCLUSION = 128 << 20
)

// FindWeights walks folder recursively and returns the files it classifies as
// model weights, keyed by the lowercase hex SHA-256 digest of their contents;
// the value is the path as reported by filepath.Walk.
//
// A regular file is classified as a weight when all of the following hold:
//   - its extension is not in WEIGHT_FILE_EXCLUSIONS,
//   - it is larger than WEIGHT_FILE_SIZE_EXCLUSION bytes, and
//   - its extension is in WEIGHT_FILE_INCLUSIONS, or it is at least
//     WEIGHT_FILE_SIZE_INCLUSION bytes.
//
// Files with identical contents share a digest, so only the last one visited
// is kept. Any error encountered while walking, opening, or reading aborts
// the walk; in that case the returned map is nil.
func FindWeights(folder string) (map[string]string, error) {
	weights := make(map[string]string)

	err := filepath.Walk(folder, func(path string, info os.FileInfo, err error) error {
		// Walk passes a non-nil err (and possibly a nil info) for entries it
		// could not stat or list, e.g. a missing root. Touching info in that
		// case would panic, so surface the error instead.
		if err != nil {
			return err
		}
		// Directories, symlinks and other special files are never weights.
		if !info.Mode().IsRegular() {
			return nil
		}

		ext := filepath.Ext(path)
		if slices.Contains(WEIGHT_FILE_EXCLUSIONS, ext) || info.Size() <= WEIGHT_FILE_SIZE_EXCLUSION {
			return nil
		}
		if !slices.Contains(WEIGHT_FILE_INCLUSIONS, ext) && info.Size() < WEIGHT_FILE_SIZE_INCLUSION {
			return nil
		}

		file, err := os.Open(path)
		if err != nil {
			return err
		}
		// The defer fires when this callback returns, i.e. once per file.
		defer file.Close()

		// Stream the file through the hash so large weights are never held
		// in memory.
		hash := sha256.New()
		if _, err := io.Copy(hash, file); err != nil {
			return err
		}

		weights[hex.EncodeToString(hash.Sum(nil))] = path
		return nil
	})
	if err != nil {
		return nil, err
	}

	return weights, nil
}