4 changes: 2 additions & 2 deletions modules/model/QwenModel.py
@@ -177,8 +177,8 @@ def encode_text(
#pad to 16 because attention processors and/or torch.compile can have issues with uneven sequence lengths, but only pad if an attention mask has to be used anyway:
#TODO the second condition could trigger https://github.com/pytorch/pytorch/issues/165506 again, but try like this because no attention mask
#is preferable: https://github.com/Nerogar/OneTrainer/pull/1109
- if max_seq_length % 16 > 0 and (seq_lengths != max_seq_length).any():
- max_seq_length += (16 - max_seq_length % 16)
+ if max_seq_length % 64 > 0 and (seq_lengths != max_seq_length).any():
+ max_seq_length += (64 - max_seq_length % 64)

text_encoder_output = text_encoder_output[:, :max_seq_length, :]
bool_attention_mask = tokens_mask[:, :max_seq_length].bool()
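The only functional change above is the padding granularity, from the next multiple of 16 to the next multiple of 64. A standalone sketch of the rounding arithmetic (not the OneTrainer code path, just the same guard and round-up):

```python
import torch

def pad_seq_length(max_seq_length: int, seq_lengths: torch.Tensor, multiple: int = 64) -> int:
    # Round up to the next multiple, but only when at least one sequence is shorter
    # than the longest one, i.e. when an attention mask has to be used anyway.
    if max_seq_length % multiple > 0 and (seq_lengths != max_seq_length).any():
        max_seq_length += multiple - max_seq_length % multiple
    return max_seq_length

print(pad_seq_length(77, torch.tensor([77, 77])))  # 77: no mask needed, no padding
print(pad_seq_length(77, torch.tensor([70, 77])))  # 128: mask needed, rounded up to a multiple of 64
```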
15 changes: 1 addition & 14 deletions modules/modelSampler/QwenSampler.py
@@ -97,9 +97,6 @@ def __sample_base(
if "generator" in set(inspect.signature(noise_scheduler.step).parameters.keys()):
extra_step_kwargs["generator"] = generator #TODO purpose?

- #txt_seq_lens = text_attention_mask.sum(dim=1).tolist()
- txt_seq_lens = [text_attention_mask.shape[1]] * text_attention_mask.shape[0]

#FIXME list of lists is not according to type hint, but according to diffusers code
#https://github.com/huggingface/diffusers/issues/12295
img_shapes = [[(
@@ -110,25 +107,15 @@

self.model.transformer_to(self.train_device)

- #FIXME bug workaround for https://github.com/huggingface/diffusers/issues/12294
- image_seq_len = latent_image.shape[1]
- image_attention_mask=torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=latent_image.device)
- attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
- attention_mask_2d = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

for i, timestep in enumerate(tqdm(timesteps, desc="sampling")):
latent_model_input = torch.cat([latent_image] * batch_size)
expanded_timestep = timestep.expand(batch_size)
noise_pred = transformer(
hidden_states=latent_model_input.to(dtype=self.model.train_dtype.torch_dtype()),
timestep=expanded_timestep / 1000,
encoder_hidden_states=combined_prompt_embedding.to(dtype=self.model.train_dtype.torch_dtype()),
- encoder_hidden_states_mask=text_attention_mask,
- txt_seq_lens=txt_seq_lens,
+ encoder_hidden_states_mask=text_attention_mask if not torch.all(text_attention_mask) else None,
img_shapes=img_shapes,
- attention_kwargs = {
- "attention_mask": attention_mask_2d,
- },
return_dict=True
).sample

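The removed txt_seq_lens / attention_kwargs plumbing and the new encoder_hidden_states_mask expression come down to one pattern: forward the text mask only when something is actually padded, otherwise pass None so the maskless attention path is taken. A toy illustration in plain torch (the same pattern also appears in BaseQwenSetup.predict below):

```python
import torch

def mask_or_none(text_attention_mask: torch.Tensor) -> torch.Tensor | None:
    # An all-True mask masks nothing, so it is dropped entirely.
    return text_attention_mask if not torch.all(text_attention_mask) else None

padded = torch.tensor([[True, True, True, False]])  # one padded token
full = torch.ones(1, 4, dtype=torch.bool)           # nothing padded
print(mask_or_none(padded))  # tensor([[ True,  True,  True, False]])
print(mask_or_none(full))    # None
```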
2 changes: 2 additions & 0 deletions modules/modelSetup/BaseChromaSetup.py
@@ -86,6 +86,8 @@ def setup_optimizations(
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)

+ self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def _setup_embeddings(
self,
model: ChromaModel,
2 changes: 2 additions & 0 deletions modules/modelSetup/BaseFluxSetup.py
@@ -89,6 +89,8 @@ def setup_optimizations(
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)

+ self._set_attention_backend(model.transformer, config.attention_mechanism, mask=False)

def _setup_embeddings(
self,
model: FluxModel,
2 changes: 2 additions & 0 deletions modules/modelSetup/BaseHiDreamSetup.py
@@ -105,6 +105,8 @@ def setup_optimizations(
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config)

+ self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def _setup_embeddings(
self,
model: HiDreamModel,
2 changes: 2 additions & 0 deletions modules/modelSetup/BaseHunyuanVideoSetup.py
@@ -91,6 +91,8 @@ def setup_optimizations(
quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config)

model.vae.enable_tiling()
+ self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)


def _setup_embeddings(
self,
19 changes: 19 additions & 0 deletions modules/modelSetup/BaseModelSetup.py
@@ -3,6 +3,7 @@

from modules.model.BaseModel import BaseModel
from modules.util.config.TrainConfig import TrainConfig, TrainEmbeddingConfig, TrainModelPartConfig
+ from modules.util.enum.AttentionMechanism import AttentionMechanism
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.ModuleFilter import ModuleFilter
from modules.util.NamedParameterGroup import NamedParameterGroup, NamedParameterGroupCollection
@@ -235,3 +236,21 @@ def _setup_model_part_requires_grad(
if unique_name in self.frozen_parameters:
for param in self.frozen_parameters[unique_name]:
param.requires_grad_(False)

+ @staticmethod
+ def _set_attention_backend(component, attn: AttentionMechanism, mask: bool=False, varlen: bool=False):
+     match attn:
+         case AttentionMechanism.SDP:
+             component.set_attention_backend("native")
+         case AttentionMechanism.FLASH:
+             if mask or varlen:
+                 print("Warning: FLASH attention might fail for this model, depending on other configuration (batch size > 1, etc.)")
+             component.set_attention_backend("flash")
+         case AttentionMechanism.SPLIT:
+             component.set_attention_backend("native_split")
+         case AttentionMechanism.FLASH_SPLIT:
+             component.set_attention_backend("flash_split")
+             if mask and not varlen:
+                 print("Warning: FLASH attention might fail for this model, depending on other configuration (batch size > 1, etc.)")
+         case _:
+             raise NotImplementedError(f"attention mechanism {str(attn)} not implemented")
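For illustration, a minimal sketch of how the new helper dispatches, using a stub in place of a real diffusers module. The "native_split" and "flash_split" backend names are assumed to be provided by the pinned diffusers fork in requirements-global.txt rather than by upstream diffusers.

```python
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.util.enum.AttentionMechanism import AttentionMechanism

class StubComponent:
    """Stands in for a diffusers model that exposes set_attention_backend()."""
    def set_attention_backend(self, name: str) -> None:
        print(f"backend -> {name}")

stub = StubComponent()
BaseModelSetup._set_attention_backend(stub, AttentionMechanism.SDP)               # backend -> native
BaseModelSetup._set_attention_backend(stub, AttentionMechanism.SPLIT)             # backend -> native_split
BaseModelSetup._set_attention_backend(stub, AttentionMechanism.FLASH, mask=True)  # warning, then backend -> flash
```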
1 change: 1 addition & 0 deletions modules/modelSetup/BasePixArtAlphaSetup.py
@@ -85,6 +85,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+ self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def _setup_embeddings(
self,
17 changes: 2 additions & 15 deletions modules/modelSetup/BaseQwenSetup.py
@@ -80,6 +80,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+ self._set_attention_backend(model.transformer, config.attention_mechanism, varlen=True)

def predict(
self,
@@ -131,11 +132,6 @@ def predict(
latent_input = scaled_noisy_latent_image
packed_latent_input = model.pack_latents(latent_input)

- #FIXME this is the only case that the transformer accepts:
- #see https://github.com/huggingface/diffusers/issues/12344
- #actual text sequence lengths can be shorter,but they might be padded and masked
- txt_seq_lens = [text_encoder_output.shape[1]] * text_encoder_output.shape[0]

#FIXME list of lists is not according to type hint, but according to diffusers code:
#https://github.com/huggingface/diffusers/issues/12295
img_shapes = [[(
@@ -144,21 +140,12 @@
latent_input.shape[-1] // 2)
]] * latent_input.shape[0]

- #FIXME bug workaround for https://github.com/huggingface/diffusers/issues/12294
- image_attention_mask=torch.ones((packed_latent_input.shape[0], packed_latent_input.shape[1]), dtype=torch.bool, device=latent_image.device)
- attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
- attention_mask_2d = attention_mask[:, None, None, :] if not torch.all(text_attention_mask) else None

packed_predicted_flow = model.transformer(
hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()),
timestep=timestep / 1000,
encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
- encoder_hidden_states_mask=text_attention_mask,
- txt_seq_lens=txt_seq_lens,
+ encoder_hidden_states_mask=text_attention_mask if not torch.all(text_attention_mask) else None,
img_shapes=img_shapes,
- attention_kwargs = {
- "attention_mask": attention_mask_2d,
- },
return_dict=True,
).sample

1 change: 1 addition & 0 deletions modules/modelSetup/BaseSanaSetup.py
@@ -97,6 +97,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+ self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseStableDiffusion3Setup.py
@@ -91,6 +91,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+ self._set_attention_backend(model.transformer, config.attention_mechanism)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseStableDiffusionSetup.py
@@ -71,6 +71,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.unet, self.train_device, model.train_dtype, config)
+ self._set_attention_backend(model.unet, config.attention_mechanism)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseStableDiffusionXLSetup.py
@@ -80,6 +80,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config)
quantize_layers(model.vae, self.train_device, model.vae_train_dtype, config)
quantize_layers(model.unet, self.train_device, model.train_dtype, config)
+ self._set_attention_backend(model.unet, config.attention_mechanism)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseZImageSetup.py
@@ -81,6 +81,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+ self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def predict(
self,
8 changes: 8 additions & 0 deletions modules/ui/TrainingTab.py
@@ -15,6 +15,7 @@
from modules.ui.SchedulerParamsWindow import SchedulerParamsWindow
from modules.ui.TimestepDistributionWindow import TimestepDistributionWindow
from modules.util.config.TrainConfig import TrainConfig
+ from modules.util.enum.AttentionMechanism import AttentionMechanism
from modules.util.enum.DataType import DataType
from modules.util.enum.EMAMode import EMAMode
from modules.util.enum.GradientCheckpointingMethod import GradientCheckpointingMethod
@@ -336,6 +337,13 @@ def __create_base2_frame(self, master, row, video_training_enabled: bool = False
frame.grid_columnconfigure(0, weight=1)
row = 0

+ # attention mechanism
+ components.label(frame, row, 0, "Attention",
+                  tooltip="The attention mechanism used during training. Use 'SPLIT' on Linux. On Windows, use 'SDP' or 'FLASH_SPLIT'. 'FLASH_SPLIT' can be faster, but requires FlashAttention to be installed separately and does not support all models. For very high batch sizes, the SPLIT variants might be slower.")
+ components.options(frame, row, 1, [str(x) for x in list(AttentionMechanism)], self.ui_state,
+                    "attention_mechanism")
+ row += 1

# ema
components.label(frame, row, 0, "EMA",
tooltip="EMA averages the training progress over many steps, better preserving different concepts in big datasets")
3 changes: 3 additions & 0 deletions modules/util/config/TrainConfig.py
@@ -9,6 +9,7 @@
from modules.util.config.ConceptConfig import ConceptConfig
from modules.util.config.SampleConfig import SampleConfig
from modules.util.config.SecretsConfig import SecretsConfig
+ from modules.util.enum.AttentionMechanism import AttentionMechanism
from modules.util.enum.AudioFormat import AudioFormat
from modules.util.enum.ConfigPart import ConfigPart
from modules.util.enum.DataType import DataType
@@ -422,6 +423,7 @@ class TrainConfig(BaseConfig):
only_cache: bool
resolution: str
frames: str
+ attention_mechanism: AttentionMechanism
mse_strength: float
mae_strength: float
log_cosh_strength: float
@@ -1005,6 +1007,7 @@ def default_values() -> 'TrainConfig':
data.append(("only_cache", False, bool, False))
data.append(("resolution", "512", str, False))
data.append(("frames", "25", str, False))
data.append(("attention_mechanism", AttentionMechanism.SDP, AttentionMechanism, False))
data.append(("mse_strength", 1.0, float, False))
data.append(("mae_strength", 0.0, float, False))
data.append(("log_cosh_strength", 0.0, float, False))
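A short sketch of how the new field is expected to be read and written in code (assuming the usual BaseConfig attribute access; the concrete values are illustrative):

```python
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.AttentionMechanism import AttentionMechanism

config = TrainConfig.default_values()
print(config.attention_mechanism)  # SDP, the default registered above

# e.g. what selecting "SPLIT" in the UI dropdown should end up as:
config.attention_mechanism = AttentionMechanism.SPLIT
```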
11 changes: 11 additions & 0 deletions modules/util/enum/AttentionMechanism.py
@@ -0,0 +1,11 @@
from enum import Enum


class AttentionMechanism(Enum):
SDP = 'SDP'
FLASH = 'FLASH'
SPLIT = 'SPLIT'
FLASH_SPLIT = 'FLASH_SPLIT'

def __str__(self):
return self.value
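A quick check of the enum's string behaviour, which the TrainingTab dropdown above and the config round-trip rely on (standard value-based Enum semantics):

```python
from modules.util.enum.AttentionMechanism import AttentionMechanism

print([str(x) for x in AttentionMechanism])  # ['SDP', 'FLASH', 'SPLIT', 'FLASH_SPLIT']
print(AttentionMechanism('FLASH_SPLIT'))     # FLASH_SPLIT  (lookup by value)
print(AttentionMechanism['SPLIT'] is AttentionMechanism.SPLIT)  # True  (lookup by name)
```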
2 changes: 1 addition & 1 deletion requirements-global.txt
@@ -20,7 +20,7 @@ pytorch-lightning==2.5.1.post0

# diffusion models
#Note: check whether Qwen bugs in diffusers have been fixed before upgrading diffusers (see BaseQwenSetup):
- -e git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers
+ -e git+https://github.com/dxqb/diffusers.git@split_attention#egg=diffusers
gguf==0.17.1
transformers==4.56.2
sentencepiece==0.2.1 # transitive dependency of transformers for tokenizer loading