diff --git a/modules/model/QwenModel.py b/modules/model/QwenModel.py
index afa6c24fe..71de80320 100644
--- a/modules/model/QwenModel.py
+++ b/modules/model/QwenModel.py
@@ -177,8 +177,8 @@ def encode_text(
-        #pad to 16 because attention processors and/or torch.compile can have issues with uneven sequence lengths, but only pad if an attention mask has to be used anyway:
+        #pad to a multiple of 64 because attention processors and/or torch.compile can have issues with uneven sequence lengths, but only pad if an attention mask has to be used anyway:
         #TODO the second condition could trigger https://github.com/pytorch/pytorch/issues/165506 again, but try like this because no attention mask
         #is preferable: https://github.com/Nerogar/OneTrainer/pull/1109
-        if max_seq_length % 16 > 0 and (seq_lengths != max_seq_length).any():
-            max_seq_length += (16 - max_seq_length % 16)
+        if max_seq_length % 64 > 0 and (seq_lengths != max_seq_length).any():
+            max_seq_length += (64 - max_seq_length % 64)
 
         text_encoder_output = text_encoder_output[:, :max_seq_length, :]
         bool_attention_mask = tokens_mask[:, :max_seq_length].bool()
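The hunk above only changes the padding multiple from 16 to 64. A minimal standalone sketch of the resulting behaviour (the `padded_seq_length` helper name and the example inputs are illustrative, not part of the PR): the longest sequence is rounded up to the next multiple of 64 only when the batch contains shorter sequences, i.e. when an attention mask is needed anyway.

```python
import torch

def padded_seq_length(seq_lengths: torch.Tensor, multiple: int = 64) -> int:
    # Hypothetical helper mirroring the logic in QwenModel.encode_text:
    # round the longest sequence up to the next multiple of `multiple`,
    # but only if the batch contains shorter (i.e. masked) sequences.
    max_seq_length = int(seq_lengths.max())
    if max_seq_length % multiple > 0 and (seq_lengths != max_seq_length).any():
        max_seq_length += multiple - max_seq_length % multiple
    return max_seq_length

print(padded_seq_length(torch.tensor([50, 50])))  # 50: equal lengths, no mask, no padding
print(padded_seq_length(torch.tensor([50, 30])))  # 64: mixed lengths, pad up to a multiple of 64
```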
diff --git a/modules/modelSampler/QwenSampler.py b/modules/modelSampler/QwenSampler.py
index 798bc54cd..0fc27c7a0 100644
--- a/modules/modelSampler/QwenSampler.py
+++ b/modules/modelSampler/QwenSampler.py
@@ -97,9 +97,6 @@ def __sample_base(
         if "generator" in set(inspect.signature(noise_scheduler.step).parameters.keys()):
            extra_step_kwargs["generator"] = generator
         #TODO purpose?
-        #txt_seq_lens = text_attention_mask.sum(dim=1).tolist()
-        txt_seq_lens = [text_attention_mask.shape[1]] * text_attention_mask.shape[0]
-
         #FIXME list of lists is not according to type hint, but according to diffusers code
         #https://github.com/huggingface/diffusers/issues/12295
         img_shapes = [[(
@@ -110,12 +107,6 @@
 
         self.model.transformer_to(self.train_device)
 
-        #FIXME bug workaround for https://github.com/huggingface/diffusers/issues/12294
-        image_seq_len = latent_image.shape[1]
-        image_attention_mask=torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=latent_image.device)
-        attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
-        attention_mask_2d = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
-
         for i, timestep in enumerate(tqdm(timesteps, desc="sampling")):
             latent_model_input = torch.cat([latent_image] * batch_size)
             expanded_timestep = timestep.expand(batch_size)
@@ -123,12 +114,8 @@
                 hidden_states=latent_model_input.to(dtype=self.model.train_dtype.torch_dtype()),
                 timestep=expanded_timestep / 1000,
                 encoder_hidden_states=combined_prompt_embedding.to(dtype=self.model.train_dtype.torch_dtype()),
-                encoder_hidden_states_mask=text_attention_mask,
-                txt_seq_lens=txt_seq_lens,
+                encoder_hidden_states_mask=text_attention_mask if not torch.all(text_attention_mask) else None,
                 img_shapes=img_shapes,
-                attention_kwargs = {
-                    "attention_mask": attention_mask_2d,
-                },
                 return_dict=True
             ).sample
 
diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py
index 7a7847df7..25ed8d155 100644
--- a/modules/modelSetup/BaseChromaSetup.py
+++ b/modules/modelSetup/BaseChromaSetup.py
@@ -86,6 +86,8 @@ def setup_optimizations(
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
 
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
+
     def _setup_embeddings(
         self,
         model: ChromaModel,
diff --git a/modules/modelSetup/BaseFluxSetup.py b/modules/modelSetup/BaseFluxSetup.py
index 1865f382e..621837d17 100644
--- a/modules/modelSetup/BaseFluxSetup.py
+++ b/modules/modelSetup/BaseFluxSetup.py
@@ -89,6 +89,8 @@ def setup_optimizations(
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
 
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=False)
+
     def _setup_embeddings(
         self,
         model: FluxModel,
diff --git a/modules/modelSetup/BaseHiDreamSetup.py b/modules/modelSetup/BaseHiDreamSetup.py
index 48e691ecd..b9e646272 100644
--- a/modules/modelSetup/BaseHiDreamSetup.py
+++ b/modules/modelSetup/BaseHiDreamSetup.py
@@ -105,6 +105,8 @@ def setup_optimizations(
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config)
 
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
+
     def _setup_embeddings(
         self,
         model: HiDreamModel,
diff --git a/modules/modelSetup/BaseHunyuanVideoSetup.py b/modules/modelSetup/BaseHunyuanVideoSetup.py
index bbb90b71a..ddec0b122 100644
--- a/modules/modelSetup/BaseHunyuanVideoSetup.py
+++ b/modules/modelSetup/BaseHunyuanVideoSetup.py
@@ -91,6 +91,8 @@ def setup_optimizations(
 
         quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config)
         model.vae.enable_tiling()
 
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
+
     def _setup_embeddings(
         self,
diff --git a/modules/modelSetup/BaseModelSetup.py b/modules/modelSetup/BaseModelSetup.py
index 5b0e7df95..a4f8384c4 100644
--- a/modules/modelSetup/BaseModelSetup.py
+++ b/modules/modelSetup/BaseModelSetup.py
@@ -3,6 +3,7 @@
 
 from modules.model.BaseModel import BaseModel
 from modules.util.config.TrainConfig import TrainConfig, TrainEmbeddingConfig, TrainModelPartConfig
+from modules.util.enum.AttentionMechanism import AttentionMechanism
 from modules.util.enum.TrainingMethod import TrainingMethod
 from modules.util.ModuleFilter import ModuleFilter
 from modules.util.NamedParameterGroup import NamedParameterGroup, NamedParameterGroupCollection
@@ -235,3 +236,21 @@ def _setup_model_part_requires_grad(
         if unique_name in self.frozen_parameters:
             for param in self.frozen_parameters[unique_name]:
                 param.requires_grad_(False)
+
+    @staticmethod
+    def _set_attention_backend(component, attn: AttentionMechanism, mask: bool = False, varlen: bool = False):
+        match attn:
+            case AttentionMechanism.SDP:
+                component.set_attention_backend("native")
+            case AttentionMechanism.FLASH:
+                if mask or varlen:
+                    print("Warning: FLASH attention might fail for this model, depending on other configuration options (e.g. batch size > 1)")
+                component.set_attention_backend("flash")
+            case AttentionMechanism.SPLIT:
+                component.set_attention_backend("native_split")
+            case AttentionMechanism.FLASH_SPLIT:
+                component.set_attention_backend("flash_split")
+                if mask and not varlen:
+                    print("Warning: FLASH attention might fail for this model, depending on other configuration options (e.g. batch size > 1)")
+            case _:
+                raise NotImplementedError(f"attention mechanism {str(attn)} not implemented")
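A minimal sketch of how the new `_set_attention_backend` helper dispatches the enum to diffusers backend names, using a stand-in object instead of a real diffusers transformer (the `_FakeTransformer` class is purely illustrative, not part of the PR; importing `BaseModelSetup` pulls in the project's dependencies):

```python
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.util.enum.AttentionMechanism import AttentionMechanism

class _FakeTransformer:
    # Stand-in for a diffusers model exposing set_attention_backend().
    def set_attention_backend(self, name: str):
        print(f"backend set to: {name}")

component = _FakeTransformer()
for attn in AttentionMechanism:
    # With mask=True and varlen=False, FLASH and FLASH_SPLIT also print the warning above.
    BaseModelSetup._set_attention_backend(component, attn, mask=True)
# -> native, flash, native_split, flash_split
```

Each model setup class passes `mask`/`varlen` according to whether its transformer is expected to receive an attention mask, which is what the per-model `setup_optimizations` changes below encode.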
diff --git a/modules/modelSetup/BasePixArtAlphaSetup.py b/modules/modelSetup/BasePixArtAlphaSetup.py
index 8240fb5f4..38993344e 100644
--- a/modules/modelSetup/BasePixArtAlphaSetup.py
+++ b/modules/modelSetup/BasePixArtAlphaSetup.py
@@ -85,6 +85,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
 
     def _setup_embeddings(
         self,
diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py
index dc7115274..2d6ad58db 100644
--- a/modules/modelSetup/BaseQwenSetup.py
+++ b/modules/modelSetup/BaseQwenSetup.py
@@ -80,6 +80,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.transformer, config.attention_mechanism, varlen=True)
 
     def predict(
         self,
@@ -131,11 +132,6 @@
             latent_input = scaled_noisy_latent_image
             packed_latent_input = model.pack_latents(latent_input)
 
-            #FIXME this is the only case that the transformer accepts:
-            #see https://github.com/huggingface/diffusers/issues/12344
-            #actual text sequence lengths can be shorter,but they might be padded and masked
-            txt_seq_lens = [text_encoder_output.shape[1]] * text_encoder_output.shape[0]
-
             #FIXME list of lists is not according to type hint, but according to diffusers code:
             #https://github.com/huggingface/diffusers/issues/12295
             img_shapes = [[(
@@ -144,21 +140,12 @@
                 latent_input.shape[-1] // 2)
             ]] * latent_input.shape[0]
 
-            #FIXME bug workaround for https://github.com/huggingface/diffusers/issues/12294
-            image_attention_mask=torch.ones((packed_latent_input.shape[0], packed_latent_input.shape[1]), dtype=torch.bool, device=latent_image.device)
-            attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
-            attention_mask_2d = attention_mask[:, None, None, :] if not torch.all(text_attention_mask) else None
-
             packed_predicted_flow = model.transformer(
                 hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()),
                 timestep=timestep / 1000,
                 encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
-                encoder_hidden_states_mask=text_attention_mask,
-                txt_seq_lens=txt_seq_lens,
+                encoder_hidden_states_mask=text_attention_mask if not torch.all(text_attention_mask) else None,
                 img_shapes=img_shapes,
-                attention_kwargs = {
-                    "attention_mask": attention_mask_2d,
-                },
                 return_dict=True,
             ).sample
 
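Both the trainer and the sampler now drop `txt_seq_lens` and the 2D attention-mask workaround, and only forward `encoder_hidden_states_mask` when some text token is actually masked; a fully-True mask becomes `None` so the maskless path can be taken. A small sketch of that pattern (the `optional_mask` helper name is hypothetical, the inline expression is what the diff uses):

```python
import torch

def optional_mask(text_attention_mask: torch.Tensor) -> torch.Tensor | None:
    # Mirrors `text_attention_mask if not torch.all(text_attention_mask) else None`:
    # drop the mask entirely when nothing is masked.
    return text_attention_mask if not torch.all(text_attention_mask) else None

full = torch.ones(2, 64, dtype=torch.bool)
padded = full.clone()
padded[1, 48:] = False

print(optional_mask(full))          # None: no padding, no mask is passed
print(optional_mask(padded).shape)  # torch.Size([2, 64]): mask is kept
```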
diff --git a/modules/modelSetup/BaseSanaSetup.py b/modules/modelSetup/BaseSanaSetup.py
index 84078ff6f..a9b49bd11 100644
--- a/modules/modelSetup/BaseSanaSetup.py
+++ b/modules/modelSetup/BaseSanaSetup.py
@@ -97,6 +97,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
 
     def _setup_embeddings(
         self,
diff --git a/modules/modelSetup/BaseStableDiffusion3Setup.py b/modules/modelSetup/BaseStableDiffusion3Setup.py
index 5015b21af..af07860c3 100644
--- a/modules/modelSetup/BaseStableDiffusion3Setup.py
+++ b/modules/modelSetup/BaseStableDiffusion3Setup.py
@@ -91,6 +91,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.transformer, config.attention_mechanism)
 
     def _setup_embeddings(
         self,
diff --git a/modules/modelSetup/BaseStableDiffusionSetup.py b/modules/modelSetup/BaseStableDiffusionSetup.py
index 0fc6ed0df..020081fb5 100644
--- a/modules/modelSetup/BaseStableDiffusionSetup.py
+++ b/modules/modelSetup/BaseStableDiffusionSetup.py
@@ -71,6 +71,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder, self.train_device, model.train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.unet, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.unet, config.attention_mechanism)
 
     def _setup_embeddings(
         self,
diff --git a/modules/modelSetup/BaseStableDiffusionXLSetup.py b/modules/modelSetup/BaseStableDiffusionXLSetup.py
index 37121951a..eb842a884 100644
--- a/modules/modelSetup/BaseStableDiffusionXLSetup.py
+++ b/modules/modelSetup/BaseStableDiffusionXLSetup.py
@@ -80,6 +80,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.vae_train_dtype, config)
         quantize_layers(model.unet, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.unet, config.attention_mechanism)
 
     def _setup_embeddings(
         self,
diff --git a/modules/modelSetup/BaseZImageSetup.py b/modules/modelSetup/BaseZImageSetup.py
index 727122f8b..d5f4628df 100644
--- a/modules/modelSetup/BaseZImageSetup.py
+++ b/modules/modelSetup/BaseZImageSetup.py
@@ -81,6 +81,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
 
     def predict(
         self,
diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py
index 9cc2bcec8..49db1a363 100644
--- a/modules/ui/TrainingTab.py
+++ b/modules/ui/TrainingTab.py
@@ -15,6 +15,7 @@
 from modules.ui.SchedulerParamsWindow import SchedulerParamsWindow
 from modules.ui.TimestepDistributionWindow import TimestepDistributionWindow
 from modules.util.config.TrainConfig import TrainConfig
+from modules.util.enum.AttentionMechanism import AttentionMechanism
 from modules.util.enum.DataType import DataType
 from modules.util.enum.EMAMode import EMAMode
 from modules.util.enum.GradientCheckpointingMethod import GradientCheckpointingMethod
@@ -336,6 +337,13 @@ def __create_base2_frame(self, master, row, video_training_enabled: bool = False
         frame.grid_columnconfigure(0, weight=1)
         row = 0
 
+        # attention mechanism
+        components.label(frame, row, 0, "Attention",
+                         tooltip="The attention mechanism used during training. Use 'SPLIT' on Linux. On Windows, use 'SDP' or 'FLASH_SPLIT'. 'FLASH_SPLIT' can be faster, but requires an additional installation and does not support all models. For very high batch sizes, the SPLIT variants might be slower.")
+        components.options(frame, row, 1, [str(x) for x in list(AttentionMechanism)], self.ui_state,
+                           "attention_mechanism")
+        row += 1
+
         # ema
         components.label(frame, row, 0, "EMA",
                          tooltip="EMA averages the training progress over many steps, better preserving different concepts in big datasets")
diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py
index ddaee4b89..1873008c8 100644
--- a/modules/util/config/TrainConfig.py
+++ b/modules/util/config/TrainConfig.py
@@ -9,6 +9,7 @@
 from modules.util.config.ConceptConfig import ConceptConfig
 from modules.util.config.SampleConfig import SampleConfig
 from modules.util.config.SecretsConfig import SecretsConfig
+from modules.util.enum.AttentionMechanism import AttentionMechanism
 from modules.util.enum.AudioFormat import AudioFormat
 from modules.util.enum.ConfigPart import ConfigPart
 from modules.util.enum.DataType import DataType
@@ -422,6 +423,7 @@ class TrainConfig(BaseConfig):
     only_cache: bool
     resolution: str
     frames: str
+    attention_mechanism: AttentionMechanism
     mse_strength: float
     mae_strength: float
     log_cosh_strength: float
@@ -1005,6 +1007,7 @@ def default_values() -> 'TrainConfig':
         data.append(("only_cache", False, bool, False))
         data.append(("resolution", "512", str, False))
         data.append(("frames", "25", str, False))
+        data.append(("attention_mechanism", AttentionMechanism.SDP, AttentionMechanism, False))
         data.append(("mse_strength", 1.0, float, False))
         data.append(("mae_strength", 0.0, float, False))
         data.append(("log_cosh_strength", 0.0, float, False))
diff --git a/modules/util/enum/AttentionMechanism.py b/modules/util/enum/AttentionMechanism.py
new file mode 100644
index 000000000..a5e2d5e96
--- /dev/null
+++ b/modules/util/enum/AttentionMechanism.py
@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class AttentionMechanism(Enum):
+    SDP = 'SDP'
+    FLASH = 'FLASH'
+    SPLIT = 'SPLIT'
+    FLASH_SPLIT = 'FLASH_SPLIT'
+
+    def __str__(self):
+        return self.value
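The enum's string round trip is what ties the UI dropdown, the config default, and the dispatch helper together. A short illustrative snippet (how `BaseConfig` actually serializes and restores the field is not shown in this diff):

```python
from modules.util.enum.AttentionMechanism import AttentionMechanism

# The UI builds its option list from the string values ...
print([str(x) for x in list(AttentionMechanism)])  # ['SDP', 'FLASH', 'SPLIT', 'FLASH_SPLIT']

# ... and a stored string maps back to an enum member.
print(AttentionMechanism("SDP") is AttentionMechanism.SDP)                  # True
print(AttentionMechanism["FLASH_SPLIT"] is AttentionMechanism.FLASH_SPLIT)  # True
```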
diff --git a/requirements-global.txt b/requirements-global.txt
index afa429bf2..829ebedd0 100644
--- a/requirements-global.txt
+++ b/requirements-global.txt
@@ -20,7 +20,7 @@ pytorch-lightning==2.5.1.post0
 
 # diffusion models
 #Note: check whether Qwen bugs in diffusers have been fixed before upgrading diffusers (see BaseQwenSetup):
--e git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers
+-e git+https://github.com/dxqb/diffusers.git@split_attention#egg=diffusers
 gguf==0.17.1
 transformers==4.56.2
 sentencepiece==0.2.1 # transitive dependency of transformers for tokenizer loading
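Not part of the PR: since the diffusers pin now points at the `dxqb/diffusers` fork's `split_attention` branch, a quick sanity check along these lines (assuming the installed build exposes `ModelMixin.set_attention_backend`, which `_set_attention_backend` relies on) can confirm the right package is active:

```python
import diffusers
from diffusers import ModelMixin

print(diffusers.__version__)
# Expected True with the pinned fork; False would suggest an older diffusers
# without the attention-backend dispatcher is installed.
print(hasattr(ModelMixin, "set_attention_backend"))
```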