@@ -251,7 +251,6 @@ def __call__(
251251 joint_attention_kwargs : Optional [Dict [str , Any ]] = None ,
252252 callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] = None ,
253253 callback_on_step_end_tensor_inputs : Optional [List [str ]] = None ,
254- max_sequence_length : int = 128 ,
255254 step_callback : Callable [[PipelineIntermediateState ], None ] = None ,
256255 ):
257256 r"""
@@ -342,7 +341,6 @@ def __call__(
342341 prompt_embeds = prompt_embeds ,
343342 negative_prompt_embeds = negative_prompt_embeds ,
344343 callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs ,
345- max_sequence_length = max_sequence_length ,
346344 )
347345
348346 self ._guidance_scale = guidance_scale
@@ -416,15 +414,15 @@ def __call__(
416414 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
417415
418416 # Init Invoke step callback
419- # step_callback(
420- # PipelineIntermediateState(
421- # step=0,
422- # order=1,
423- # total_steps=num_inference_steps,
424- # timestep=int(timesteps[0]),
425- # latents=latents,
426- # ),
427- # )
417+ step_callback (
418+ PipelineIntermediateState (
419+ step = 0 ,
420+ order = 1 ,
421+ total_steps = num_inference_steps ,
422+ timestep = int (timesteps [0 ]),
423+ latents = latents . view ( 1 , 64 , 64 , 4 , 2 , 2 ). permute ( 0 , 3 , 1 , 4 , 2 , 5 ). reshape ( 1 , 4 , 128 , 128 ) ,
424+ ),
425+ )
428426
429427 # EYAL - added the CFG loop
430428 # 7. Denoising loop
@@ -513,15 +511,15 @@ def __call__(
513511 # call the callback, if provided
514512 if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
515513 progress_bar .update ()
516- # step_callback(
517- # PipelineIntermediateState(
518- # step=i + 1,
519- # order=1,
520- # total_steps=num_inference_steps,
521- # timestep=int(t),
522- # latents=latents,
523- # ),
524- # )
514+ step_callback (
515+ PipelineIntermediateState (
516+ step = i + 1 ,
517+ order = 1 ,
518+ total_steps = num_inference_steps ,
519+ timestep = int (t ),
520+ latents = latents . view ( 1 , 64 , 64 , 4 , 2 , 2 ). permute ( 0 , 3 , 1 , 4 , 2 , 5 ). reshape ( 1 , 4 , 128 , 128 ) ,
521+ ),
522+ )
525523
526524 if output_type == "latent" :
527525 image = latents
0 commit comments