Skip to content

Add in support for 2.0 #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ __pycache__
.*~
sd_*jpg
sd_*png
*.code-workspace
9 changes: 8 additions & 1 deletion examples/stable-diffusion/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ struct Args {
enum StableDiffusionVersion {
V1_5,
V2_1,
V2_0,
}

impl Args {
Expand All @@ -147,6 +148,7 @@ impl Args {
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(),
StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(),
StableDiffusionVersion::V2_0 => "data/clip_v2.0.safetensors".to_string(),
},
}
}
Expand All @@ -157,6 +159,7 @@ impl Args {
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(),
StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(),
StableDiffusionVersion::V2_0 => "data/vae_v2.0.safetensors".to_string(),
},
}
}
Expand All @@ -167,6 +170,7 @@ impl Args {
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(),
StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(),
StableDiffusionVersion::V2_0 => "data/unet_v2.0.safetensors".to_string(),
},
}
}
Expand Down Expand Up @@ -229,6 +233,9 @@ fn run(args: Args) -> anyhow::Result<()> {
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
StableDiffusionVersion::V2_0 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
};

let device_setup = diffusers::utils::DeviceSetup::new(cpu);
Expand All @@ -242,7 +249,7 @@ fn run(args: Args) -> anyhow::Result<()> {
let tokens = tokenizer.encode(&prompt)?;
let tokens: Vec<i64> = tokens.into_iter().map(|x| x as i64).collect();
let tokens = Tensor::from_slice(&tokens).view((1, -1)).to(clip_device);
let uncond_tokens = tokenizer.encode("")?;
let uncond_tokens = tokenizer.encode("words, borders, bad quality, ugly")?;
let uncond_tokens: Vec<i64> = uncond_tokens.into_iter().map(|x| x as i64).collect();
let uncond_tokens = Tensor::from_slice(&uncond_tokens).view((1, -1)).to(clip_device);

Expand Down
14 changes: 10 additions & 4 deletions scripts/get_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def ensure_data_dir(safetensors):

def get_safetensors(safetensors, weight_bits):
for name, url in safetensors.items():
print(f"Getting {name} {weight_bits} bit tensors...")
print(f"Getting {name} {weight_bits} bit tensors with {url}...")

# Download bin file
urllib.request.urlretrieve(url, os.path.join(data_path, f"{name}.bin"))
Expand Down Expand Up @@ -61,17 +61,23 @@ def get_urls(sd_version, weight_bits):
"pytorch_model": f"https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin"
}
safetensors_v2_1 = {
"vae_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/vae/diffusion_pytorch_model.bin",
"vae_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-/resolve/{branch}/vae/diffusion_pytorch_model.bin",
"unet_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/unet/diffusion_pytorch_model.bin",
"clip_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/text_encoder/pytorch_model.bin"
}
safetensors_v2_0 = {
"vae_v2.0": f"https://huggingface.co/stabilityai/stable-diffusion-2/resolve/{branch}/vae/diffusion_pytorch_model.bin",
"unet_v2.0": f"https://huggingface.co/stabilityai/stable-diffusion-2/resolve/{branch}/unet/diffusion_pytorch_model.bin",
"clip_v2.0": f"https://huggingface.co/stabilityai/stable-diffusion-2/resolve/{branch}/text_encoder/pytorch_model.bin"
}

vocab_url = "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz"

return safetensors_v1_5 if sd_version == "1.5" else safetensors_v2_1, vocab_url
return safetensors_v1_5 if sd_version == "1.5" else (safetensors_v2_1 if sd_version == "2.1" else safetensors_v2_0), vocab_url

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download weights for diffusers-rs.")
parser.add_argument("--sd_version", "-v", choices=["2.1", "1.5"], default="2.1")
parser.add_argument("--sd_version", "-v", choices=["2.1", "1.5", "2.0"], default="2.0")
parser.add_argument("--weight_bits", "-w", choices=["16", "32"], default="16")
args = parser.parse_args()

Expand Down