2424 AutoPipelineForInpainting ,
2525 AutoPipelineForText2Image ,
2626 DiffusionPipeline ,
27- FluxKontextPipeline ,
2827)
2928from diffusers .pipelines .stable_diffusion import StableDiffusionSafetyChecker
3029from diffusers .utils import load_image
@@ -485,7 +484,8 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
485484 if is_transformers_version (">=" , "4.40.0" ):
486485 SUPPORTED_ARCHITECTURES .append ("stable-diffusion-3" )
487486 SUPPORTED_ARCHITECTURES .append ("flux" )
488- SUPPORTED_ARCHITECTURES .append ("flux-kontext" )
487+ if is_diffusers_version (">=" , "0.35.0" ):
488+ SUPPORTED_ARCHITECTURES .append ("flux-kontext" )
489489
490490 AUTOMODEL_CLASS = AutoPipelineForImage2Image
491491 OVMODEL_CLASS = OVPipelineForImage2Image
@@ -502,8 +502,9 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
502502 if model_type in ["flux" , "stable-diffusion-3" , "flux-kontext" ]:
503503 inputs ["height" ] = height
504504 inputs ["width" ] = width
505-
506- inputs ["strength" ] = 0.75
505+
506+ if model_type != "flux-kontext" :
507+ inputs ["strength" ] = 0.75
507508
508509 return inputs
509510
@@ -535,7 +536,16 @@ def test_num_images_per_prompt(self, model_arch: str):
535536 height = height , width = width , batch_size = batch_size , model_type = model_arch
536537 )
537538 outputs = pipeline (** inputs , num_images_per_prompt = num_images_per_prompt ).images
538- self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , height , width , 3 ))
539+ if model_arch != "flux-kontext" :
540+ self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , height , width , 3 ))
541+ else :
542+ # output shape is fixed: https://github.com/huggingface/diffusers/blob/v0.35.1/src/diffusers/pipelines/flux/pipeline_flux_kontext.py#L882
543+ if (height == width ):
544+ self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , 1024 , 1024 , 3 ))
545+ elif (height > width ):
546+ self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , 1448 , 724 , 3 ))
547+ else :
548+ self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , 724 , 1448 , 3 ))
539549
540550 @parameterized .expand (["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ])
541551 @require_diffusers
@@ -568,8 +578,11 @@ def __call__(self, *args, **kwargs) -> None:
568578 @require_diffusers
569579 def test_shape (self , model_arch : str ):
570580 pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
571-
572- height , width , batch_size = 128 , 64 , 1
581+ if model_arch != "flux-kontext" :
582+ # output shape is fixed: https://github.com/huggingface/diffusers/blob/v0.35.1/src/diffusers/pipelines/flux/pipeline_flux_kontext.py#L882
583+ height , width , batch_size = 128 , 64 , 1
584+ else :
585+ height , width , batch_size = 1448 , 724 , 1
573586
574587 for input_type in ["pil" , "np" , "pt" ]:
575588 inputs = self .generate_inputs (
@@ -586,7 +599,7 @@ def test_shape(self, model_arch: str):
586599 elif output_type == "pt" :
587600 self .assertEqual (outputs .shape , (batch_size , 3 , height , width ))
588601 else :
589- if model_arch != "flux" and model_arch != "flux-kontext" :
602+ if not model_arch . startswith ( "flux" ) :
590603 out_channels = (
591604 pipeline .unet .config .out_channels
592605 if pipeline .unet is not None
@@ -611,9 +624,9 @@ def test_shape(self, model_arch: str):
611624 @require_diffusers
612625 def test_compare_to_diffusers_pipeline (self , model_arch : str ):
613626 height , width , batch_size = 128 , 128 , 1
614- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
615-
616- auto_cls = self . AUTOMODEL_CLASS if "flux-kontext" not in model_arch else FluxKontextPipeline
627+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
628+ auto_cls = self . AUTOMODEL_CLASS
629+
617630 diffusers_pipeline = auto_cls .from_pretrained (MODEL_NAMES [model_arch ])
618631 ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
619632
0 commit comments