Status: Open
Labels: bug, jax/flax, stale
Description
Describe the bug
I followed the blog post "Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e", and everything worked well until I tried to generate images at a different size. At 1024x1024, inference latency is ~3 s per image (compared with ~8 s on an NVIDIA A10G). At 1280x960 (height=1280, width=960), however, we see next to no improvement over the A10G.
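For context, here is a rough back-of-the-envelope comparison of the two sizes. This is a sketch, not code from the blog post. It assumes SDXL's VAE downsamples by a factor of 8 and that the UNet's attention layers run at 1/2 and 1/4 of the latent resolution (the 64x64 and 32x32 levels at 1024x1024). `sdxl_workload` is a hypothetical helper.

```python
# Hypothetical helper: rough SDXL workload arithmetic for a given output size.
# Assumptions: VAE downsampling factor 8; UNet self-attention at 1/2 and 1/4
# of the latent resolution.


def sdxl_workload(width: int, height: int) -> dict:
    if width % 8 or height % 8:
        raise ValueError("width and height must be multiples of 8")
    lat_h, lat_w = height // 8, width // 8
    # Number of tokens seen by self-attention at each assumed attention level.
    tokens = [(lat_h // f) * (lat_w // f) for f in (2, 4)]
    return {
        "latent_hw": (lat_h, lat_w),
        "pixels": width * height,
        "attn_tokens": tokens,
    }


base = sdxl_workload(1024, 1024)
tall = sdxl_workload(960, 1280)
pixel_ratio = tall["pixels"] / base["pixels"]
# The self-attention score matrix grows with the square of the token count.
attn_ratio = (tall["attn_tokens"][0] / base["attn_tokens"][0]) ** 2
print(base["latent_hw"], tall["latent_hw"])
print(f"pixel ratio {pixel_ratio:.3f}, attention-score ratio {attn_ratio:.3f}")
```

Under these assumptions, 960x1280 has about 17% more pixels and about 37% more self-attention score entries than 1024x1024. On these numbers alone, one would expect a moderate slowdown rather than the TPU's advantage disappearing.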
Reproduction
Use the same code as in the blog post: https://huggingface.co/blog/sdxl_jax
Changes (width and height are added to generate and passed through to the pipeline):
import time

import numpy as np

# tokenize_prompt, replicate_all, pipeline, p_params and the default_* values
# are defined as in the blog post.


def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
    width=1024,
    height=1024,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        width=width,    # new: forward the requested size to the pipeline
        height=height,  # new
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))


# The first call at a given size triggers compilation.
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt, width=960, height=1280)
print(f"Compiled in {time.time() - start}")

# Later calls at the same size reuse the compiled program.
start = time.time()
print("starting")
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt, width=960, height=1280)
print(f"Inference in {time.time() - start}")
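A single timed call can be noisy. The sketch below is a hypothetical `benchmark` helper, not part of the blog post. It keeps warm-up calls, where jit compilation for a new width/height happens, separate from steady-state runs and reports the median. `generate` already converts its output to NumPy/PIL, which waits for the TPU to finish, so wall-clock timing is meaningful. A function returning raw JAX arrays would need `.block_until_ready()` before the timer stops.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 1, runs: int = 5) -> dict:
    """Time fn(), keeping warm-up calls (where compilation happens) separate."""
    if runs < 1:
        raise ValueError("runs must be at least 1")

    warmup_times = []
    for _ in range(warmup):
        start = time.perf_counter()
        fn()
        warmup_times.append(time.perf_counter() - start)

    run_times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        run_times.append(time.perf_counter() - start)

    return {
        "warmup_s": warmup_times,
        "median_s": statistics.median(run_times),
        "min_s": min(run_times),
    }


# Stand-in workload; with the code above one might instead pass
# lambda: generate(prompt, neg_prompt, width=960, height=1280)
result = benchmark(lambda: sum(i * i for i in range(100_000)), warmup=1, runs=3)
print(result)
```

Running this once at 1024x1024 and once at 960x1280 would give comparable steady-state numbers for the two sizes.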
Logs
No response
System Info
Python: 3.10.6
Diffusers: 0.26.2
Torch: 2.2.0+cu121
Jax: 0.4.23
Flax: 0.8.0
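The version list above can be collected with a short script. The sketch below uses only the standard library. `system_info` is a hypothetical helper, and it assumes the PyPI distribution names `diffusers`, `torch`, `jax` and `flax`.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def system_info(packages=("diffusers", "torch", "jax", "flax")) -> dict:
    """Collect the interpreter and package versions listed under System Info."""
    info = {"Python": platform.python_version()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = "not installed"
    return info


for key, value in system_info().items():
    print(f"{key}: {value}")
```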
Who can help?
huseyintemiz