
Commit 742f30b

Merge pull request #2611 from bghira/bugfix/flux2-ramtorch
flux2: fix ramtorch validation by checking device location correctly
2 parents 0282f0b + 14b1ba9

3 files changed: 57 additions & 1 deletion

simpletuner/helpers/models/wan/model.py

Lines changed: 2 additions & 1 deletion
@@ -1357,10 +1357,11 @@ def _fixed_execution_device(self):
     """
     Fixed _execution_device property that returns the transformer device instead of meta.
     This fixes the issue when text encoder is moved to meta but transformer is on GPU.
+    Uses .device property (not raw parameter device) so ramtorch-aware patches apply.
     """
     # If we have a transformer and it's not on meta, use its device
     if hasattr(self, "transformer") and self.transformer is not None:
-        transformer_device = next(self.transformer.parameters()).device
+        transformer_device = self.transformer.device
         if transformer_device.type != "meta":
             return transformer_device
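Why this matters: `next(model.parameters()).device` reports where the weights physically live, which under ramtorch-style offloading is CPU, while a patched `.device` property can report the target GPU. A minimal, self-contained sketch of that distinction (the `ToyTransformer` class and the `cuda:0` target are illustrative, not from the commit):

```python
import torch
from torch import nn


class ToyTransformer(nn.Module):
    """Hypothetical stand-in for a ramtorch-managed transformer."""

    def __init__(self):
        super().__init__()
        # Under offloading, the raw weights sit on CPU between uses.
        self.proj = nn.Linear(4, 4)

    @property
    def device(self) -> torch.device:
        # A ramtorch-aware patch reports the *target* device,
        # not the current parameter location.
        return torch.device("cuda:0")


model = ToyTransformer()
print(next(model.parameters()).device)  # cpu     -- raw parameter lookup
print(model.device)                     # cuda:0  -- patched property
```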

simpletuner/helpers/ramtorch_extensions.py

Lines changed: 48 additions & 0 deletions
@@ -789,3 +789,51 @@ def remove_ramtorch_sync_hooks(hooks: list) -> None:
     """Remove synchronization hooks added by add_ramtorch_sync_hooks."""
     for h in hooks:
         h.remove()
+
+
+def get_ramtorch_target_device(model: nn.Module) -> torch.device | None:
+    """Return the target GPU device from a model's ramtorch modules, or None.
+
+    Returns the device of the first ramtorch module found. All ramtorch
+    modules within a model share the same target device because
+    ``replace_linear_layers_with_ramtorch`` applies a single ``device``
+    argument to every replaced layer.
+
+    Returns ``None`` when the model contains no ramtorch modules.
+    """
+    for m in model.modules():
+        if getattr(m, "is_ramtorch", False):
+            dev = m.device
+            return torch.device(dev) if isinstance(dev, str) else dev
+    return None
+
+
+_model_device_patched = False
+
+
+def patch_model_device_for_ramtorch():
+    """
+    Patch ModelMixin.device so it returns the ramtorch target GPU device
+    instead of CPU when ramtorch modules are present.
+
+    This single patch fixes:
+    - DiffusionPipeline._execution_device (delegates to pipeline.device -> model.device)
+    - Direct self.transformer.device / self.unet.device references in pipeline code
+    """
+    global _model_device_patched
+    if _model_device_patched:
+        return
+    _model_device_patched = True
+
+    from diffusers import ModelMixin
+
+    original_device = ModelMixin.device
+
+    @property
+    def device(self) -> torch.device:
+        dev = get_ramtorch_target_device(self)
+        if dev is not None:
+            return dev
+        return original_device.fget(self)
+
+    ModelMixin.device = device
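For context, the same patching pattern can be demonstrated without importing diffusers. The sketch below substitutes a hypothetical `BaseModel` for `ModelMixin` and a per-instance `is_ramtorch` flag for the module walk in `get_ramtorch_target_device`; it is a simplified illustration, not the commit's code:

```python
import torch
from torch import nn


class BaseModel(nn.Module):
    """Stand-in for diffusers' ModelMixin, for illustration only."""

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device  # the "original" behaviour


def patch_device(target_device: torch.device):
    """Same pattern as patch_model_device_for_ramtorch, on the stand-in class."""
    original_device = BaseModel.device  # keep the original property object

    @property
    def device(self) -> torch.device:
        # Prefer the ramtorch target; otherwise defer to the original getter.
        if getattr(self, "is_ramtorch", False):
            return target_device
        return original_device.fget(self)

    BaseModel.device = device


class Net(BaseModel):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)


patch_device(torch.device("cuda:0"))
net = Net()
print(net.device)        # cpu     -- no ramtorch flag, original property applies
net.is_ramtorch = True
print(net.device)        # cuda:0  -- patched branch takes over
```

Keeping a reference to the original property object and calling `original_device.fget(self)` in the fallback is what lets the patch stay transparent for non-ramtorch models.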

simpletuner/helpers/training/validation.py

Lines changed: 7 additions & 0 deletions
@@ -2131,6 +2131,13 @@ def setup_pipeline(self, validation_type):
             if getattr(te, "device", None) and te.device.type == "meta":
                 setattr(self.model.pipeline, attr, None)
 
+        # Patch ModelMixin.device so ramtorch models report the target GPU instead of CPU.
+        # Must run before pipeline.to() and pipeline.__call__ which rely on device detection.
+        if getattr(self.config, "ramtorch", False):
+            from simpletuner.helpers.ramtorch_extensions import patch_model_device_for_ramtorch
+
+            patch_model_device_for_ramtorch()
+
         # For FSDP models, skip .to() call - DTensor parameters are already device-aware
         # and calling .to() causes: "RuntimeError: Attempted to set the storage of a tensor
         # on device 'cpu' to a storage on different device 'cuda:0'"
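Because `setup_pipeline` runs on every validation pass, this call is safe to repeat: the module-level `_model_device_patched` flag makes the patch idempotent. A minimal sketch of that guard (`patch_once` is a hypothetical reduction of `patch_model_device_for_ramtorch`):

```python
_patched = False


def patch_once() -> bool:
    """Apply a one-time patch; return True only on the first call."""
    global _patched
    if _patched:
        return False
    _patched = True
    # ... perform the actual monkey-patch here ...
    return True


assert patch_once() is True    # first validation run applies the patch
assert patch_once() is False   # later runs are no-ops
```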
