diff --git a/comfy/model_base.py b/comfy/model_base.py index b0b9cde7d087..8274c7dea192 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -138,6 +138,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) + self.diffusion_model.eval() if comfy.model_management.force_channels_last(): self.diffusion_model.to(memory_format=torch.channels_last) logging.debug("using channels last mode for diffusion model") @@ -669,7 +670,6 @@ def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None): class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageC) - self.diffusion_model.eval().requires_grad_(False) def extra_conds(self, **kwargs): out = {} @@ -698,7 +698,6 @@ def extra_conds(self, **kwargs): class StableCascade_B(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageB) - self.diffusion_model.eval().requires_grad_(False) def extra_conds(self, **kwargs): out = {}