diff --git a/.gitignore b/.gitignore index ce223f9..ad52bad 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ __pycache__ .*~ sd_*jpg sd_*png +*.code-workspace \ No newline at end of file diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs index caeb483..3e2da86 100644 --- a/examples/stable-diffusion/main.rs +++ b/examples/stable-diffusion/main.rs @@ -138,6 +138,7 @@ struct Args { enum StableDiffusionVersion { V1_5, V2_1, + V2_0, } impl Args { @@ -147,6 +148,7 @@ impl Args { None => match self.sd_version { StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(), StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(), + StableDiffusionVersion::V2_0 => "data/clip_v2.0.safetensors".to_string(), }, } } @@ -157,6 +159,7 @@ impl Args { None => match self.sd_version { StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(), StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(), + StableDiffusionVersion::V2_0 => "data/vae_v2.0.safetensors".to_string(), }, } } @@ -167,6 +170,7 @@ impl Args { None => match self.sd_version { StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(), StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(), + StableDiffusionVersion::V2_0 => "data/unet_v2.0.safetensors".to_string(), }, } } @@ -229,6 +233,9 @@ fn run(args: Args) -> anyhow::Result<()> { StableDiffusionVersion::V2_1 => { stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width) } + StableDiffusionVersion::V2_0 => { + stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width) + } }; let device_setup = diffusers::utils::DeviceSetup::new(cpu); @@ -242,7 +249,7 @@ fn run(args: Args) -> anyhow::Result<()> { let tokens = tokenizer.encode(&prompt)?; let tokens: Vec<i64> = tokens.into_iter().map(|x| x as i64).collect(); let tokens = Tensor::from_slice(&tokens).view((1, -1)).to(clip_device); - let uncond_tokens = 
tokenizer.encode("")?; + let uncond_tokens = tokenizer.encode("words, borders, bad quality, ugly")?; let uncond_tokens: Vec<i64> = uncond_tokens.into_iter().map(|x| x as i64).collect(); let uncond_tokens = Tensor::from_slice(&uncond_tokens).view((1, -1)).to(clip_device); diff --git a/scripts/get_weights.py b/scripts/get_weights.py index 53a59c1..9163673 100644 --- a/scripts/get_weights.py +++ b/scripts/get_weights.py @@ -28,7 +28,7 @@ def ensure_data_dir(safetensors): def get_safetensors(safetensors, weight_bits): for name, url in safetensors.items(): - print(f"Getting {name} {weight_bits} bit tensors...") + print(f"Getting {name} {weight_bits} bit tensors with {url}...") # Download bin file urllib.request.urlretrieve(url, os.path.join(data_path, f"{name}.bin")) @@ -61,17 +61,23 @@ def get_urls(sd_version, weight_bits): "pytorch_model": f"https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin" } safetensors_v2_1 = { - "vae_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/vae/diffusion_pytorch_model.bin", + "vae_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/vae/diffusion_pytorch_model.bin", "unet_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/unet/diffusion_pytorch_model.bin", "clip_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/text_encoder/pytorch_model.bin" } + safetensors_v2_0 = { + "vae_v2.0": f"https://huggingface.co/stabilityai/stable-diffusion-2/resolve/{branch}/vae/diffusion_pytorch_model.bin", + "unet_v2.0": f"https://huggingface.co/stabilityai/stable-diffusion-2/resolve/{branch}/unet/diffusion_pytorch_model.bin", + "clip_v2.0": f"https://huggingface.co/stabilityai/stable-diffusion-2/resolve/{branch}/text_encoder/pytorch_model.bin" + } + vocab_url = "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz" - return safetensors_v1_5 if sd_version == "1.5" else safetensors_v2_1, 
vocab_url + return safetensors_v1_5 if sd_version == "1.5" else (safetensors_v2_1 if sd_version == "2.1" else safetensors_v2_0), vocab_url if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download weights for diffusers-rs.") - parser.add_argument("--sd_version", "-v", choices=["2.1", "1.5"], default="2.1") + parser.add_argument("--sd_version", "-v", choices=["2.1", "1.5", "2.0"], default="2.0") parser.add_argument("--weight_bits", "-w", choices=["16", "32"], default="16") args = parser.parse_args()