Skip to content

Commit 4b64b56

Browse files
bhavya01, yiyixuxu, and github-actions[bot]
authored
Change timestep device to cpu for xla (#11501)
* Change timestep device to cpu for xla
* Add all pipelines
* ruff format
* Apply style fixes

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 2bb640f commit 4b64b56

File tree

76 files changed

+444
-82
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+444
-82
lines changed

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -887,7 +887,13 @@ def __call__(
887887
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
888888

889889
# 4. Prepare timesteps
890-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
890+
if XLA_AVAILABLE:
891+
timestep_device = "cpu"
892+
else:
893+
timestep_device = device
894+
timesteps, num_inference_steps = retrieve_timesteps(
895+
self.scheduler, num_inference_steps, timestep_device, timesteps
896+
)
891897
self.scheduler.set_timesteps(num_inference_steps, device=device)
892898

893899
# 5. Prepare latents.

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -897,16 +897,20 @@ def __call__(
897897
dtype = self.dtype
898898

899899
# 3. Prepare timesteps
900+
if XLA_AVAILABLE:
901+
timestep_device = "cpu"
902+
else:
903+
timestep_device = device
900904
if not enforce_inference_steps:
901905
timesteps, num_inference_steps = retrieve_timesteps(
902-
self.scheduler, num_inference_steps, device, timesteps, sigmas
906+
self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
903907
)
904908
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
905909
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
906910
else:
907911
denoising_inference_steps = int(num_inference_steps / strength)
908912
timesteps, denoising_inference_steps = retrieve_timesteps(
909-
self.scheduler, denoising_inference_steps, device, timesteps, sigmas
913+
self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas
910914
)
911915
timesteps = timesteps[-num_inference_steps:]
912916
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1100,16 +1100,20 @@ def __call__(
11001100
dtype = self.dtype
11011101

11021102
# 3. Prepare timesteps
1103+
if XLA_AVAILABLE:
1104+
timestep_device = "cpu"
1105+
else:
1106+
timestep_device = device
11031107
if not enforce_inference_steps:
11041108
timesteps, num_inference_steps = retrieve_timesteps(
1105-
self.scheduler, num_inference_steps, device, timesteps, sigmas
1109+
self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
11061110
)
11071111
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
11081112
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
11091113
else:
11101114
denoising_inference_steps = int(num_inference_steps / strength)
11111115
timesteps, denoising_inference_steps = retrieve_timesteps(
1112-
self.scheduler, denoising_inference_steps, device, timesteps, sigmas
1116+
self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas
11131117
)
11141118
timesteps = timesteps[-num_inference_steps:]
11151119
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -586,7 +586,13 @@ def __call__(
586586
# 4. Prepare timesteps
587587

588588
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
589-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
589+
if XLA_AVAILABLE:
590+
timestep_device = "cpu"
591+
else:
592+
timestep_device = device
593+
timesteps, num_inference_steps = retrieve_timesteps(
594+
self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
595+
)
590596

591597
# 5. Prepare latents.
592598
latent_channels = self.transformer.config.in_channels

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -664,7 +664,13 @@ def __call__(
664664
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
665665

666666
# 4. Prepare timesteps
667-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
667+
if XLA_AVAILABLE:
668+
timestep_device = "cpu"
669+
else:
670+
timestep_device = device
671+
timesteps, num_inference_steps = retrieve_timesteps(
672+
self.scheduler, num_inference_steps, timestep_device, timesteps
673+
)
668674
self._num_timesteps = len(timesteps)
669675

670676
# 5. Prepare latents

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -717,7 +717,13 @@ def __call__(
717717
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
718718

719719
# 4. Prepare timesteps
720-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
720+
if XLA_AVAILABLE:
721+
timestep_device = "cpu"
722+
else:
723+
timestep_device = device
724+
timesteps, num_inference_steps = retrieve_timesteps(
725+
self.scheduler, num_inference_steps, timestep_device, timesteps
726+
)
721727
self._num_timesteps = len(timesteps)
722728

723729
# 5. Prepare latents

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -762,7 +762,13 @@ def __call__(
762762
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
763763

764764
# 4. Prepare timesteps
765-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
765+
if XLA_AVAILABLE:
766+
timestep_device = "cpu"
767+
else:
768+
timestep_device = device
769+
timesteps, num_inference_steps = retrieve_timesteps(
770+
self.scheduler, num_inference_steps, timestep_device, timesteps
771+
)
766772
self._num_timesteps = len(timesteps)
767773

768774
# 5. Prepare latents

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -737,7 +737,13 @@ def __call__(
737737
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
738738

739739
# 4. Prepare timesteps
740-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
740+
if XLA_AVAILABLE:
741+
timestep_device = "cpu"
742+
else:
743+
timestep_device = device
744+
timesteps, num_inference_steps = retrieve_timesteps(
745+
self.scheduler, num_inference_steps, timestep_device, timesteps
746+
)
741747
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
742748
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
743749
self._num_timesteps = len(timesteps)

src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -566,7 +566,13 @@ def __call__(
566566
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
567567

568568
# 4. Prepare timesteps
569-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
569+
if XLA_AVAILABLE:
570+
timestep_device = "cpu"
571+
else:
572+
timestep_device = device
573+
timesteps, num_inference_steps = retrieve_timesteps(
574+
self.scheduler, num_inference_steps, timestep_device, timesteps
575+
)
570576
self._num_timesteps = len(timesteps)
571577

572578
# 5. Prepare latents.

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -599,8 +599,12 @@ def __call__(
599599
self.scheduler.config.get("base_shift", 0.25),
600600
self.scheduler.config.get("max_shift", 0.75),
601601
)
602+
if XLA_AVAILABLE:
603+
timestep_device = "cpu"
604+
else:
605+
timestep_device = device
602606
timesteps, num_inference_steps = retrieve_timesteps(
603-
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
607+
self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu
604608
)
605609
self._num_timesteps = len(timesteps)
606610

0 commit comments

Comments (0)