Skip to content

Commit 8209ffb

Browse files
authored
Merge branch 'develop' into add_audio_to_output
2 parents 23a0289 + 0241dc6 commit 8209ffb

File tree

18 files changed

+344
-107
lines changed

18 files changed

+344
-107
lines changed
Lines changed: 3 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,4 @@
1-
import sys
1+
from shared import LoraLoader
22

3-
if sys.version_info >= (3, 11):
4-
from enum import StrEnum
5-
# StrEnum is introduced in 3.11 while we support python 3.10
6-
else:
7-
from enum import Enum, auto
8-
from typing import Any
9-
10-
# Fallback for Python 3.10 and earlier
11-
class StrEnum(str, Enum):
12-
def __new__(cls, value, *args, **kwargs):
13-
if not isinstance(value, (str, auto)):
14-
raise TypeError(
15-
f"Values of StrEnums must be strings: {value!r} is a {type(value)}"
16-
)
17-
return super().__new__(cls, value, *args, **kwargs)
18-
19-
def __str__(self):
20-
return str(self.value)
21-
22-
@staticmethod
23-
def _generate_next_value_(
24-
name: str, start: int, count: int, last_values: list[Any]
25-
) -> str:
26-
return name
27-
28-
29-
class LoraLoader(StrEnum):
30-
DIFFUSERS = "diffusers"
31-
LORA_READY = "lora_ready"
32-
DEFAULT = LORA_READY
33-
34-
@staticmethod
35-
def supported_values() -> list[str]:
36-
"""Returns a list of all supported LoraLoader values."""
37-
return [loader.value for loader in LoraLoader]
38-
39-
@staticmethod
40-
def safe_parse(value: "str | LoraLoader") -> "LoraLoader":
41-
if isinstance(value, LoraLoader):
42-
return value
43-
try:
44-
return LoraLoader(value)
45-
except ValueError:
46-
return LoraLoader.DEFAULT
47-
48-
49-
if __name__ == "__main__":
50-
# Test the StrEnum functionality
51-
print("diffusers:", LoraLoader.DIFFUSERS) # Should print "diffusers"
52-
print("lora_ready:", LoraLoader.LORA_READY) # Should print "lora_ready"
53-
print("default:", LoraLoader.DEFAULT) # Should print "lora_ready"
54-
print( # Should print all unique supported values (excludes aliases like DEFAULT)
55-
"supported_values:", LoraLoader.supported_values()
56-
)
57-
try:
58-
print("fail:", LoraLoader("invalid")) # Should raise ValueError
59-
except ValueError as e:
60-
print("pass:", e) # Prints: Invalid LoraLoader value: invalid
61-
try:
62-
print("pass:", LoraLoader("diffusers")) # Should return LoraLoader.DIFFUSERS
63-
except ValueError as e:
64-
print("fail:", e)
65-
try:
66-
print("type of LoraLoader.DEFAULT:", type(LoraLoader.DEFAULT))
67-
default = LoraLoader.DEFAULT
68-
print("type of default:", type(default)) # Should be LoraLoader, not str
69-
except Exception as e:
70-
print(f"fail: {e}")
71-
72-
assert isinstance(LoraLoader("lora_ready"), StrEnum)
73-
assert isinstance(LoraLoader.DIFFUSERS, LoraLoader), (
74-
"DIFFUSERS should be an instance of LoraLoader"
75-
)
76-
assert LoraLoader.DEFAULT == LoraLoader.DIFFUSERS, (
77-
"Default loader should be DIFFUSERS"
78-
)
79-
assert LoraLoader.DIFFUSERS != LoraLoader.LORA_READY, (
80-
"DIFFUSERS should not equal LORA_READY"
81-
)
82-
83-
assert LoraLoader.LORA_READY.value == "lora_ready", (
84-
"lora_ready string should equal LoraLoader.LORA_READY"
85-
)
3+
# todo: remove this import when the diffusers_helper is updated to use the new enums directly
4+
__all__ = ["LoraLoader"]

modules/generators/base_generator.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from diffusers_helper import lora_utils
66
from typing import List, Optional, cast
77
from pathlib import Path
8+
from transformers import BitsAndBytesConfig
89

910
from diffusers_helper.lora_utils_kohya_ss.enums import LoraLoader
1011
from diffusers_helper.models.hunyuan_video_packed import (
@@ -13,6 +14,7 @@
1314

1415
from ..settings import Settings
1516
from .model_configuration import ModelConfiguration
17+
from shared import QuantizationFormat
1618

1719
# cSpell: ignore loras
1820

@@ -23,6 +25,9 @@ class BaseModelGenerator(ABC):
2325
This defines the common interface that all model generators must implement.
2426
"""
2527

28+
quantization_format: QuantizationFormat = QuantizationFormat.DEFAULT
29+
quantization_config: BitsAndBytesConfig | None = None
30+
2631
def __init__(
2732
self,
2833
text_encoder,
@@ -72,9 +77,32 @@ def __init__(
7277
self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7378
self.cpu = torch.device("cpu")
7479

80+
# quantization is currently global, configured in settings
81+
# maybe add kwargs if we need this to be more dynamic per job?
82+
self.quantization_format = self.settings.get(
83+
"quantization_format", QuantizationFormat.integer_8bit
84+
)
85+
self.set_quantization_config()
86+
7587
self.previous_model_hash: str = ""
7688
self.previous_model_configuration: ModelConfiguration | None = None
7789

90+
def set_quantization_config(self):
91+
if self.quantization_format == QuantizationFormat.brain_floating_point_16bit:
92+
# BF16 does not require a special config
93+
pass
94+
if self.quantization_format == QuantizationFormat.normal_float_4bit:
95+
# 4-bit NF4 quantization config
96+
self.quantization_config = BitsAndBytesConfig(
97+
load_in_4bit=True,
98+
bnb_4bit_compute_dtype="bfloat16",
99+
bnb_4bit_quant_type="nf4",
100+
bnb_4bit_use_double_quant=True,
101+
)
102+
if self.quantization_format == QuantizationFormat.integer_8bit:
103+
# 8-bit integer quantization config
104+
self.quantization_config = BitsAndBytesConfig(load_in_8bit=True)
105+
78106
@abstractmethod
79107
def load_model(self) -> HunyuanVideoTransformer3DModelPacked:
80108
"""
@@ -389,6 +417,7 @@ def load_loras(
389417

390418
active_model_configuration = ModelConfiguration.from_lora_names_and_weights(
391419
self.get_model_name(),
420+
self.quantization_format,
392421
selected_loras,
393422
selected_lora_values,
394423
self.settings.lora_loader,

modules/generators/f1_generator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55
from diffusers_helper.memory import DynamicSwapInstaller
66
from .base_generator import BaseModelGenerator
7+
from shared import QuantizationFormat, timer
78

89

910
class F1ModelGenerator(BaseModelGenerator):
@@ -28,6 +29,7 @@ def get_model_name(self):
2829
"""
2930
return self.model_name
3031

32+
@timer
3133
def load_model(self):
3234
"""
3335
Load the F1 transformer model.
@@ -44,12 +46,15 @@ def load_model(self):
4446

4547
# Create the transformer model
4648
self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
47-
path_to_load, torch_dtype=torch.bfloat16
49+
path_to_load,
50+
torch_dtype=torch.bfloat16,
51+
quantization_config=self.quantization_config,
4852
).cpu()
4953

5054
# Configure the model
5155
self.transformer.eval()
52-
self.transformer.to(dtype=torch.bfloat16)
56+
if self.quantization_format == QuantizationFormat.brain_floating_point_16bit:
57+
self.transformer.to(dtype=torch.bfloat16)
5358
self.transformer.requires_grad_(False)
5459

5560
# Set up dynamic swap if not in high VRAM mode

modules/generators/model_configuration.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def add_lora_setting(self, setting: ModelLoraSetting) -> None:
304304
@dataclass
305305
class ModelConfiguration:
306306
model_name: str
307+
quantization_format: str
307308
settings: ModelSettings = field(default_factory=ModelSettings)
308309

309310
@property
@@ -330,7 +331,9 @@ def validate(self) -> bool:
330331
return valid
331332

332333
@staticmethod
333-
def from_settings(model_name: str, settings: ModelSettings | dict | None):
334+
def from_settings(
335+
model_name: str, quantization_format: str, settings: ModelSettings | dict | None
336+
):
334337
model_settings: ModelSettings | None = None
335338
if settings is None:
336339
model_settings = ModelSettings()
@@ -344,11 +347,16 @@ def from_settings(model_name: str, settings: ModelSettings | dict | None):
344347
if model_settings is None:
345348
raise ValueError("Invalid config type for ModelConfiguration")
346349

347-
return ModelConfiguration(model_name=model_name, settings=model_settings)
350+
return ModelConfiguration(
351+
model_name=model_name,
352+
quantization_format=quantization_format,
353+
settings=model_settings,
354+
)
348355

349356
@staticmethod
350357
def from_lora_names_and_weights(
351358
model_name: str,
359+
quantization_format: str,
352360
lora_names: list[str],
353361
lora_weights: list[float | int],
354362
lora_loader: str | LoraLoader,
@@ -374,7 +382,9 @@ def from_lora_names_and_weights(
374382
lora_settings=lora_settings, lora_loader=str(lora_loader)
375383
)
376384
return ModelConfiguration.from_settings(
377-
model_name=model_name, settings=model_settings
385+
model_name=model_name,
386+
quantization_format=quantization_format,
387+
settings=model_settings,
378388
)
379389

380390
def set_model_name(self, model_name: str) -> "ModelConfiguration":

modules/generators/original_generator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55
from diffusers_helper.memory import DynamicSwapInstaller
66
from .base_generator import BaseModelGenerator
7+
from shared import QuantizationFormat, timer
78

89

910
class OriginalModelGenerator(BaseModelGenerator):
@@ -26,6 +27,7 @@ def get_model_name(self):
2627
"""
2728
return self.model_name
2829

30+
@timer
2931
def load_model(self):
3032
"""
3133
Load the Original transformer model.
@@ -42,12 +44,15 @@ def load_model(self):
4244

4345
# Create the transformer model
4446
self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
45-
path_to_load, torch_dtype=torch.bfloat16
47+
path_to_load,
48+
torch_dtype=torch.bfloat16,
49+
quantization_config=self.quantization_config,
4650
).cpu()
4751

4852
# Configure the model
4953
self.transformer.eval()
50-
self.transformer.to(dtype=torch.bfloat16)
54+
if self.quantization_format == QuantizationFormat.brain_floating_point_16bit:
55+
self.transformer.to(dtype=torch.bfloat16)
5156
self.transformer.requires_grad_(False)
5257

5358
# Set up dynamic swap if not in high VRAM mode

modules/generators/video_base_generator.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from diffusers_helper.bucket_tools import find_nearest_bucket
1515
from diffusers_helper.hunyuan import vae_encode
1616
from .base_generator import BaseModelGenerator
17+
from shared import QuantizationFormat, timer
1718

1819

1920
class VideoBaseModelGenerator(BaseModelGenerator):
@@ -55,6 +56,7 @@ def get_model_name(self):
5556
"""
5657
return self.model_name
5758

59+
@timer
5860
def load_model(self):
5961
"""
6062
Load the Video transformer model.
@@ -71,12 +73,15 @@ def load_model(self):
7173

7274
# Create the transformer model
7375
self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
74-
path_to_load, torch_dtype=torch.bfloat16
76+
path_to_load,
77+
torch_dtype=torch.bfloat16,
78+
quantization_config=self.quantization_config,
7579
).cpu()
7680

7781
# Configure the model
7882
self.transformer.eval()
79-
self.transformer.to(dtype=torch.bfloat16)
83+
if self.quantization_format == QuantizationFormat.brain_floating_point_16bit:
84+
self.transformer.to(dtype=torch.bfloat16)
8085
self.transformer.requires_grad_(False)
8186

8287
# Set up dynamic swap if not in high VRAM mode
@@ -585,8 +590,8 @@ def combine_videos(self, source_video_path, generated_video_path, output_path):
585590
)
586591

587592
# Get the ffmpeg executable from the VideoProcessor class
588-
from modules.toolbox.toolbox_processor import VideoProcessor
589593
from modules.toolbox.message_manager import MessageManager
594+
from modules.toolbox.toolbox_processor import VideoProcessor
590595

591596
# Create a message manager for logging
592597
message_manager = MessageManager()

modules/interface.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,16 @@ def apply_startup_settings():
334334
)
335335
connect_audio_events(a, settings)
336336

337-
def refresh_loras():
337+
def refresh_loras(current_selected):
338338
if enumerate_lora_dir_fn:
339339
new_lora_names = enumerate_lora_dir_fn()
340-
return gr.update(choices=new_lora_names)
340+
preserved = [name for name in (current_selected or []) if name in new_lora_names]
341+
return gr.update(choices=new_lora_names, value=preserved)
341342
return gr.update()
342343

343-
g["refresh_loras_button"].click(fn=refresh_loras, outputs=[g["lora_selector"]])
344+
g["refresh_loras_button"].click(
345+
fn=refresh_loras, inputs=[g["lora_selector"]], outputs=[g["lora_selector"]]
346+
)
344347

345348
# General Connections
346349
def initial_gallery_load():

modules/pipelines/worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from modules.llm_enhancer import unload_enhancing_model
3434
from . import create_pipeline
3535
from modules.studio_manager import StudioManager
36+
from shared import timer
3637

3738
# cSpell: disable hunyan, loras
3839

@@ -87,6 +88,7 @@ def get_cached_or_encode_prompt(
8788

8889

8990
@torch.no_grad()
91+
@timer
9092
def worker(
9193
model_type,
9294
input_image,

modules/settings.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from pathlib import Path
33
from typing import Dict, Any, Optional
44
import os
5-
from diffusers_helper.lora_utils_kohya_ss.enums import LoraLoader
5+
6+
from shared import LoraLoader, QuantizationFormat
67

78

89
class Settings:
@@ -45,6 +46,7 @@ def __init__(self):
4546
Enhanced prompt:""",
4647
"lora_loader": LoraLoader.DEFAULT, # lora_loader options: diffusers, lora_ready. DEFAULT is existing behavior of diffusers
4748
"reuse_model_instance": False, # Reuse model instance across generations - default of False is existing behavior
49+
"quantization_format": QuantizationFormat.DEFAULT, # Default quantization format
4850
}
4951
self.settings = self.load_settings()
5052

@@ -73,6 +75,21 @@ def lora_loader(self, value: str | LoraLoader):
7375
def reuse_model_instance(self) -> bool:
7476
return self.settings.get("reuse_model_instance", False)
7577

78+
@property
79+
def quantization_format(self) -> QuantizationFormat:
80+
return QuantizationFormat.safe_parse(
81+
self.settings.get("quantization_format", QuantizationFormat.DEFAULT)
82+
)
83+
84+
@quantization_format.setter
85+
def quantization_format(self, value: str | QuantizationFormat):
86+
if not value:
87+
value = QuantizationFormat.DEFAULT
88+
if isinstance(value, str):
89+
value = QuantizationFormat.safe_parse(value)
90+
91+
self.set("quantization_format", value)
92+
7693
def load_settings(self) -> Dict[str, Any]:
7794
"""Load settings from file or return defaults"""
7895
if self.settings_file.exists():

0 commit comments

Comments
 (0)