Skip to content

Commit 6246c70

Browse files
iczawsayakpaul
andauthored
[Community] PromptDiffusion Pipeline (#6752)
* Create promptdiffusioncontrolnet.py * Update __init__.py Added PromptDiffusionControlNetModel * Update __init__.py Added PromptDiffusionControlNetModel * Update promptdiffusioncontrolnet.py * Create pipeline_prompt_diffusion.py Added Prompt Diffusion pipeline. * Create convert_original_promptdiffusion_to_diffusers.py * Update convert_from_ckpt.py Added download_promptdiffusion_from_original_ckpt, convert_promptdiffusion_checkpoint * Update promptdiffusioncontrolnet.py * Update pipeline_prompt_diffusion.py * Update README.md * Update pipeline_prompt_diffusion.py * Delete src/diffusers/models/promptdiffusioncontrolnet.py * Update __init__.py * Update __init__.py * Delete scripts/convert_original_promptdiffusion_to_diffusers.py * Update convert_from_ckpt.py * Update README.md * Delete examples/community/pipeline_prompt_diffusion.py * Create README.md * Create promptdiffusioncontrolnet.py * Create convert_original_promptdiffusion_to_diffusers.py * Create pipeline_prompt_diffusion.py * Update README.md * Update pipeline_prompt_diffusion.py * Update README.md * Update pipeline_prompt_diffusion.py * Update convert_original_promptdiffusion_to_diffusers.py * Update promptdiffusioncontrolnet.py * Update README.md --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 577b8a2 commit 6246c70

File tree

4 files changed

+3874
-0
lines changed

4 files changed

+3874
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# PromptDiffusion Pipeline
2+
3+
From the project [page](https://zhendong-wang.github.io/prompt-diffusion.github.io/)
4+
5+
"With a prompt consisting of a task-specific example pair of images and text guidance, and a new query image, Prompt Diffusion can comprehend the desired task and generate the corresponding output image on both seen (trained) and unseen (new) task types."
6+
7+
For any usage questions, please refer to the [paper](https://arxiv.org/abs/2305.01115).
8+
9+
Prepare models by converting them from the [checkpoint](https://huggingface.co/zhendongw/prompt-diffusion)
10+
11+
To convert the controlnet, use cldm_v15.yaml from the [repository](https://github.com/Zhendong-Wang/Prompt-Diffusion/tree/main/models/):
12+
13+
```bash
14+
python convert_original_promptdiffusion_to_diffusers.py --checkpoint_path path-to-network-step04999.ckpt --original_config_file path-to-cldm_v15.yaml --dump_path path-to-output-directory
15+
```
16+
17+
To learn about how to convert the fine-tuned stable diffusion model, see the [Load different Stable Diffusion formats guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/other-formats).
18+
19+
20+
```py
21+
import torch
22+
from diffusers import UniPCMultistepScheduler
23+
from diffusers.utils import load_image
24+
from promptdiffusioncontrolnet import PromptDiffusionControlNetModel
25+
from pipeline_prompt_diffusion import PromptDiffusionPipeline
26+
27+
28+
from PIL import ImageOps
29+
30+
image_a = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house_line.png?raw=true"))
31+
32+
image_b = load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house.png?raw=true")
33+
query = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/new_01.png?raw=true"))
34+
35+
# load prompt diffusion controlnet and prompt diffusion
36+
37+
controlnet = PromptDiffusionControlNetModel.from_pretrained("iczaw/prompt-diffusion-diffusers", subfolder="controlnet", torch_dtype=torch.float16)
38+
model_id = "path-to-model"
39+
pipe = PromptDiffusionPipeline.from_pretrained("iczaw/prompt-diffusion-diffusers", subfolder="base", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16")
40+
41+
# speed up diffusion process with faster scheduler and memory optimization
42+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
43+
# remove following line if xformers is not installed
44+
pipe.enable_xformers_memory_efficient_attention()
45+
pipe.enable_model_cpu_offload()
46+
# generate image
47+
generator = torch.manual_seed(0)
48+
image = pipe("a tortoise", num_inference_steps=20, generator=generator, image_pair=[image_a,image_b], image=query).images[0]
49+
50+
```

0 commit comments

Comments
 (0)