1- from typing import List , Tuple
1+ from typing import Callable , List , Tuple
22
33import torch
44from diffusers .models .autoencoders .autoencoder_kl import AutoencoderKL
55from diffusers .schedulers .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
66
77from invokeai .app .invocations .bria_controlnet import BriaControlNetField
8- from invokeai .app .invocations .fields import Input , InputField , LatentsField , OutputField
8+ from invokeai .app .invocations .bria_latent_noise import BriaLatentNoiseOutput
9+ from invokeai .app .invocations .fields import FluxConditioningField , Input , InputField , LatentsField , OutputField
910from invokeai .app .invocations .model import SubModelType , T5EncoderField , TransformerField , VAEField
1011from invokeai .app .invocations .primitives import BaseInvocationOutput , FieldDescriptions
1112from invokeai .app .services .shared .invocation_context import InvocationContext
1213from invokeai .backend .bria .controlnet_bria import BriaControlModes , BriaMultiControlNetModel
1314from invokeai .backend .bria .controlnet_utils import prepare_control_images
1415from invokeai .backend .bria .pipeline_bria_controlnet import BriaControlNetPipeline
1516from invokeai .backend .bria .transformer_bria import BriaTransformer2DModel
17+ from invokeai .backend .model_manager .taxonomy import BaseModelType
18+ from invokeai .backend .stable_diffusion .extensions .preview import PipelineIntermediateState
1619from invokeai .invocation_api import BaseInvocation , Classification , invocation , invocation_output
1720
1821
@@ -30,6 +33,11 @@ class BriaDenoiseInvocationOutput(BaseInvocationOutput):
3033 classification = Classification .Prototype ,
3134)
3235class BriaDenoiseInvocation (BaseInvocation ):
36+
37+ """
38+ Denoise Bria latents using a Bria Pipeline.
39+ """
40+
3341 num_steps : int = InputField (
3442 default = 30 , title = "Number of Steps" , description = "The number of steps to use for the denoiser"
3543 )
@@ -52,31 +60,31 @@ class BriaDenoiseInvocation(BaseInvocation):
5260 input = Input .Connection ,
5361 title = "VAE" ,
5462 )
55- latents : LatentsField = InputField (
56- description = "Latents to denoise" ,
57- input = Input . Connection ,
58- title = "Latents " ,
63+ height : int = InputField (
64+ default = 1024 ,
65+ title = "Height" ,
66+ description = "The height of the output image " ,
5967 )
60- latent_image_ids : LatentsField = InputField (
61- description = "Latent Image IDs to denoise" ,
68+ width : int = InputField (
69+ default = 1024 ,
70+ title = "Width" ,
71+ description = "The width of the output image" ,
72+ )
73+ latent_noise : BriaLatentNoiseOutput = InputField (
74+ description = "Latent noise to denoise" ,
6275 input = Input .Connection ,
63- title = "Latent Image IDs " ,
76+ title = "Latent Noise " ,
6477 )
65- pos_embeds : LatentsField = InputField (
78+ pos_embeds : FluxConditioningField = InputField (
6679 description = "Positive Prompt Embeds" ,
6780 input = Input .Connection ,
6881 title = "Positive Prompt Embeds" ,
6982 )
70- neg_embeds : LatentsField = InputField (
83+ neg_embeds : FluxConditioningField = InputField (
7184 description = "Negative Prompt Embeds" ,
7285 input = Input .Connection ,
7386 title = "Negative Prompt Embeds" ,
7487 )
75- text_ids : LatentsField = InputField (
76- description = "Text IDs" ,
77- input = Input .Connection ,
78- title = "Text IDs" ,
79- )
8088 control : BriaControlNetField | list [BriaControlNetField ] | None = InputField (
8189 description = "ControlNet" ,
8290 input = Input .Connection ,
@@ -86,11 +94,10 @@ class BriaDenoiseInvocation(BaseInvocation):
8694
8795 @torch .no_grad ()
8896 def invoke (self , context : InvocationContext ) -> BriaDenoiseInvocationOutput :
89- latents = context .tensors .load (self .latents .latents_name )
90- pos_embeds = context .tensors .load (self .pos_embeds .latents_name )
91- neg_embeds = context .tensors .load (self .neg_embeds .latents_name )
92- text_ids = context .tensors .load (self .text_ids .latents_name )
93- latent_image_ids = context .tensors .load (self .latent_image_ids .latents_name )
97+ latents = context .tensors .load (self .latent_noise .latents .latents_name )
98+ pos_embeds = context .tensors .load (self .pos_embeds .conditioning_name )
99+ neg_embeds = context .tensors .load (self .neg_embeds .conditioning_name )
100+ latent_image_ids = context .tensors .load (self .latent_noise .latent_image_ids .latents_name )
94101 scheduler_identifier = self .transformer .transformer .model_copy (update = {"submodel_type" : SubModelType .Scheduler })
95102
96103 device = None
@@ -114,11 +121,12 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
114121 control_model , control_images , control_modes , control_scales = self ._prepare_multi_control (
115122 context = context ,
116123 vae = vae ,
117- width = 1024 ,
118- height = 1024 ,
124+ width = self . width ,
125+ height = self . height ,
119126 device = vae .device ,
120127 )
121128
129+
122130 pipeline = BriaControlNetPipeline (
123131 transformer = transformer ,
124132 scheduler = scheduler ,
@@ -129,31 +137,32 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
129137 )
130138 pipeline .to (device = transformer .device , dtype = transformer .dtype )
131139
132- latents = pipeline (
140+ output_latents = pipeline (
133141 control_image = control_images ,
134142 control_mode = control_modes ,
135- width = 1024 ,
136- height = 1024 ,
143+ width = self . width ,
144+ height = self . height ,
137145 controlnet_conditioning_scale = control_scales ,
138146 num_inference_steps = self .num_steps ,
139147 max_sequence_length = 128 ,
140148 guidance_scale = self .guidance_scale ,
141149 latents = latents ,
142150 latent_image_ids = latent_image_ids ,
143- text_ids = text_ids ,
144151 prompt_embeds = pos_embeds ,
145152 negative_prompt_embeds = neg_embeds ,
146153 output_type = "latent" ,
154+ step_callback = _build_step_callback (context ),
147155 )[0 ]
148156
149- assert isinstance (latents , torch .Tensor )
150- saved_input_latents_tensor = context .tensors .save (latents )
151- latents_output = LatentsField (latents_name = saved_input_latents_tensor )
152- return BriaDenoiseInvocationOutput (latents = latents_output )
157+
158+
159+ assert isinstance (output_latents , torch .Tensor )
160+ saved_input_latents_tensor = context .tensors .save (output_latents )
161+ return BriaDenoiseInvocationOutput (latents = LatentsField (latents_name = saved_input_latents_tensor ))
153162
154163 def _prepare_multi_control (
155164 self , context : InvocationContext , vae : AutoencoderKL , width : int , height : int , device : torch .device
156- ) -> Tuple [BriaMultiControlNetModel , List [torch .Tensor ], List [torch . Tensor ], List [float ]]:
165+ ) -> Tuple [BriaMultiControlNetModel , List [torch .Tensor ], List [int ], List [float ]]:
157166 control = self .control if isinstance (self .control , list ) else [self .control ]
158167 control_images , control_models , control_modes , control_scales = [], [], [], []
159168 for controlnet in control :
@@ -178,3 +187,11 @@ def _prepare_multi_control(
178187 device = device ,
179188 )
180189 return control_model , tensored_control_images , tensored_control_modes , control_scales
190+
191+
192+ def _build_step_callback (context : InvocationContext ) -> Callable [[PipelineIntermediateState ], None ]:
193+ def step_callback (state : PipelineIntermediateState ) -> None :
194+ return
195+ context .util .sd_step_callback (state , BaseModelType .Bria )
196+
197+ return step_callback
0 commit comments