# ModelSetupDiffusionLossMixin.py
from abc import ABCMeta
from typing import Callable

from modules.module.AestheticScoreModel import AestheticScoreModel
from modules.module.HPSv2ScoreModel import HPSv2ScoreModel
from modules.util.config.TrainConfig import TrainConfig
from modules.util.DiffusionScheduleCoefficients import DiffusionScheduleCoefficients
from modules.util.enum.AlignPropLoss import AlignPropLoss
from modules.util.enum.LossScaler import LossScaler
from modules.util.enum.LossWeight import LossWeight
from modules.util.loss.masked_loss import masked_losses
from modules.util.loss.vb_loss import vb_losses

import torch
import torch.nn.functional as F
from torch import Tensor


class ModelSetupDiffusionLossMixin(metaclass=ABCMeta):
    __coefficients: DiffusionScheduleCoefficients | None
    __alphas_cumprod_fun: Callable[[Tensor, int], Tensor] | None

    def __init__(self):
        super().__init__()
        self.__align_prop_loss_fn = None
        self.__coefficients = None
        self.__alphas_cumprod_fun = None

    def __align_prop_losses(
            self,
            batch: dict,
            data: dict,
            config: TrainConfig,
            train_device: torch.device,
    ):
        if self.__align_prop_loss_fn is None:
            dtype = data["predicted"].dtype

            match config.align_prop_loss:
                case AlignPropLoss.HPS:
                    self.__align_prop_loss_fn = HPSv2ScoreModel(dtype)
                case AlignPropLoss.AESTHETIC:
                    self.__align_prop_loss_fn = AestheticScoreModel()

            self.__align_prop_loss_fn.to(device=train_device, dtype=dtype)
            self.__align_prop_loss_fn.requires_grad_(False)
            self.__align_prop_loss_fn.eval()

        losses = 0

        match config.align_prop_loss:
            case AlignPropLoss.HPS:
                with torch.autocast(device_type=train_device.type, dtype=data["predicted"].dtype):
                    losses = self.__align_prop_loss_fn(data["predicted"], batch["prompt"], train_device)
            case AlignPropLoss.AESTHETIC:
                losses = self.__align_prop_loss_fn(data["predicted"])

        return losses * config.align_prop_weight

    def __log_cosh_loss(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
    ):
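        # Numerically stable log-cosh: log(cosh(x)) = x + softplus(-2x) - log(2).
        # Evaluating torch.cosh directly overflows to inf for |x| above roughly 89
        # in float32, while this form stays finite. Illustrative sanity check (not
        # part of the training path):
        #     x = torch.tensor([0.0, 1.0, 10.0])
        #     torch.allclose(torch.log(torch.cosh(x)),
        #                    x + F.softplus(-2.0 * x) - torch.log(torch.tensor(2.0)))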
        diff = pred - target
        loss = (
            diff
            + torch.nn.functional.softplus(-2.0 * diff)
            - torch.log(torch.full(size=diff.size(), fill_value=2.0, dtype=torch.float32, device=diff.device))
        )
        return loss

    def __masked_losses(
            self,
            batch: dict,
            data: dict,
            config: TrainConfig,
    ):
        losses = 0

        # MSE/L2 Loss
        if config.mse_strength != 0:
            losses += (
                masked_losses(
                    losses=F.mse_loss(data["predicted"].to(dtype=torch.float32), data["target"].to(dtype=torch.float32), reduction="none"),
                    mask=batch["latent_mask"].to(dtype=torch.float32),
                    unmasked_weight=config.unmasked_weight,
                    normalize_masked_area_loss=config.normalize_masked_area_loss,
                ).mean([1, 2, 3])
                * config.mse_strength
            )

        # MAE/L1 Loss
        if config.mae_strength != 0:
            losses += (
                masked_losses(
                    losses=F.l1_loss(data["predicted"].to(dtype=torch.float32), data["target"].to(dtype=torch.float32), reduction="none"),
                    mask=batch["latent_mask"].to(dtype=torch.float32),
                    unmasked_weight=config.unmasked_weight,
                    normalize_masked_area_loss=config.normalize_masked_area_loss,
                ).mean([1, 2, 3])
                * config.mae_strength
            )

        # log-cosh Loss
        if config.log_cosh_strength != 0:
            losses += (
                masked_losses(
                    losses=self.__log_cosh_loss(data["predicted"].to(dtype=torch.float32), data["target"].to(dtype=torch.float32)),
                    mask=batch["latent_mask"].to(dtype=torch.float32),
                    unmasked_weight=config.unmasked_weight,
                    normalize_masked_area_loss=config.normalize_masked_area_loss,
                ).mean([1, 2, 3])
                * config.log_cosh_strength
            )

        # VB loss
        if config.vb_loss_strength != 0 and "predicted_var_values" in data and self.__coefficients is not None:
            losses += (
                masked_losses(
                    losses=vb_losses(
                        coefficients=self.__coefficients,
                        x_0=data["scaled_latent_image"].to(dtype=torch.float32),
                        x_t=data["noisy_latent_image"].to(dtype=torch.float32),
                        t=data["timestep"],
                        predicted_eps=data["predicted"].to(dtype=torch.float32),
                        predicted_var_values=data["predicted_var_values"].to(dtype=torch.float32),
                    ),
                    mask=batch["latent_mask"].to(dtype=torch.float32),
                    unmasked_weight=config.unmasked_weight,
                    normalize_masked_area_loss=config.normalize_masked_area_loss,
                ).mean([1, 2, 3])
                * config.vb_loss_strength
            )

        return losses

    def __unmasked_losses(
            self,
            batch: dict,
            data: dict,
            config: TrainConfig,
    ):
        losses = 0

        # MSE/L2 Loss
        if config.mse_strength != 0:
            losses += (
                F.mse_loss(data["predicted"].to(dtype=torch.float32), data["target"].to(dtype=torch.float32), reduction="none")
                .mean([1, 2, 3])
                * config.mse_strength
            )

        # MAE/L1 Loss
        if config.mae_strength != 0:
            losses += (
                F.l1_loss(data["predicted"].to(dtype=torch.float32), data["target"].to(dtype=torch.float32), reduction="none")
                .mean([1, 2, 3])
                * config.mae_strength
            )

        # log-cosh Loss
        if config.log_cosh_strength != 0:
            losses += (
                self.__log_cosh_loss(data["predicted"].to(dtype=torch.float32), data["target"].to(dtype=torch.float32))
                .mean([1, 2, 3])
                * config.log_cosh_strength
            )

        # VB loss
        # Guard on self.__coefficients to mirror __masked_losses; vb_losses needs the
        # precomputed schedule coefficients and would fail if they are None.
        if config.vb_loss_strength != 0 and "predicted_var_values" in data and self.__coefficients is not None:
            losses += (
                vb_losses(
                    coefficients=self.__coefficients,
                    x_0=data["scaled_latent_image"].to(dtype=torch.float32),
                    x_t=data["noisy_latent_image"].to(dtype=torch.float32),
                    t=data["timestep"],
                    predicted_eps=data["predicted"].to(dtype=torch.float32),
                    predicted_var_values=data["predicted_var_values"].to(dtype=torch.float32),
                ).mean([1, 2, 3])
                * config.vb_loss_strength
            )

        if config.masked_training and config.normalize_masked_area_loss:
            clamped_mask = torch.clamp(batch["latent_mask"], config.unmasked_weight, 1)
            mask_mean = clamped_mask.mean(dim=(1, 2, 3))
            losses /= mask_mean

        return losses
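
    # For a discrete DDPM schedule, the signal-to-noise ratio at timestep t is
    #     SNR(t) = alpha_bar_t / (1 - alpha_bar_t) = (sqrt(alpha_bar_t) / sqrt(1 - alpha_bar_t)) ** 2,
    # which is what the coefficient branch below computes. E.g. alpha_bar_t = 0.5
    # gives SNR = 1 (signal and noise balanced); early, lightly-noised timesteps
    # have SNR >> 1, while very noisy late timesteps have SNR << 1.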
    def __snr(self, timesteps: Tensor, device: torch.device):
        if self.__coefficients:
            all_snr = (self.__coefficients.sqrt_alphas_cumprod / self.__coefficients.sqrt_one_minus_alphas_cumprod) ** 2
            # Tensor.to() is not in-place; the result must be reassigned.
            all_snr = all_snr.to(device)
            snr = all_snr[timesteps]
        else:
            alphas_cumprod = self.__alphas_cumprod_fun(timesteps, 1)
            snr = alphas_cumprod / (1.0 - alphas_cumprod)
        return snr
"""
This is where the __min_snr_weight function was originally, but because I didn't use it,
I replaced it with my custom loss function, for laziness and convenience. But if you know
what you're doing, you can put it in any other place, or even modify the OT code enough
to add the function and activate it through the UI.
"""
    def __sangoi_loss_modifier(self, timesteps: Tensor, predicted: Tensor, target: Tensor, gamma: float, device: torch.device) -> Tensor:
        """
        Source: https://github.com/sangoi-exe/sangoi-loss-function

        Computes a loss modifier based on the Mean Absolute Percentage Error (MAPE) and
        the Signal-to-Noise Ratio (SNR). The modifier scales the loss according to the
        prediction accuracy and the difficulty of the prediction task.

        Args:
            timesteps (Tensor): The current training step's timesteps.
            predicted (Tensor): Predicted values from the neural network.
            target (Tensor): Ground truth target values.
            gamma (float): A scaling factor (unused in this function).
            device (torch.device): The device on which tensors are allocated.

        Returns:
            Tensor: A tensor of per-example weights to modify the loss.
        """
        # Define minimum and maximum SNR values to clamp extreme values
        min_snr = 1e-4
        max_snr = 100

        # Obtain the SNR for each timestep
        snr = self.__snr(timesteps, device)
        # Clamp the SNR values to the defined range to avoid extreme values
        snr = torch.clamp(snr, min=min_snr, max=max_snr)

        # Define a small epsilon to prevent division by zero
        epsilon = 1e-8
        # Compute the Mean Absolute Percentage Error (MAPE)
        mape = torch.abs((target - predicted) / (target + epsilon))
        # Clamp MAPE values to the range [0, 1]
        mape = torch.clamp(mape, min=0, max=1)
        # Average the MAPE per example across spatial dimensions
        mape = mape.mean(dim=[1, 2, 3])

        # Compute the SNR weight using the natural logarithm (adding 1 to avoid log(0))
        snr_weight = torch.log(snr + 1)
        # Invert MAPE to represent accuracy instead of error
        mape_reward = 1 - mape
        # Combine: negative exponential of the product of MAPE reward and SNR weight
        combined_weight = torch.exp(-mape_reward * snr_weight)
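        # Equivalent closed form: exp(-r * log(1 + snr)) == (1 + snr) ** (-r) with
        # r = mape_reward, i.e. a per-example power-law down-weighting of low-noise
        # (high-SNR) timesteps that grows stronger as predictions become more accurate.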
        # Return the tensor of per-example weights to modify the loss
        return combined_weight

    def __debiased_estimation_weight(self, timesteps: Tensor, v_prediction: bool, device: torch.device) -> Tensor:
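        # Debiased-estimation weighting scales each timestep's loss by 1 / sqrt(SNR(t)).
        # With the clip below, the effective weight is
        #     epsilon-prediction: 1 / sqrt(min(SNR(t), 1e3))
        #     v-prediction:       1 / sqrt(min(SNR(t), 1e3) + 1)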
        snr = self.__snr(timesteps, device)
        weight = snr
        # The line below is a departure from the original paper.
        # This is to match the Kohya implementation, see: https://github.com/kohya-ss/sd-scripts/pull/889
        # In addition, it helps avoid numerical instability.
        torch.clip(weight, max=1.0e3, out=weight)
        if v_prediction:
            weight += 1.0
        torch.rsqrt(weight, out=weight)
        return weight

    def __p2_loss_weight(
            self,
            timesteps: Tensor,
            gamma: float,
            v_prediction: bool,
            device: torch.device,
    ) -> Tensor:
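        # P2 ("perception prioritized") weighting: w(t) = (k + SNR(t)) ** -gamma with
        # k = 1. For v-prediction targets the SNR is shifted by one before weighting.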
        snr = self.__snr(timesteps, device)
        if v_prediction:
            snr += 1.0

        return (1.0 + snr) ** -gamma

    def _diffusion_losses(
            self,
            batch: dict,
            data: dict,
            config: TrainConfig,
            train_device: torch.device,
            betas: Tensor | None = None,
            alphas_cumprod_fun: Callable[[Tensor, int], Tensor] | None = None,
    ) -> Tensor:
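        # LossScaler semantics, as implied by the two ternaries below: BATCH scales the
        # loss by the batch size, GRADIENT_ACCUMULATION by the number of accumulation
        # steps, any other non-NONE setting (e.g. a combined mode, if the enum defines
        # one) by both, and NONE leaves the loss unscaled.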
        loss_weight = batch["loss_weight"]
        batch_size_scale = 1 if config.loss_scaler in [LossScaler.NONE, LossScaler.GRADIENT_ACCUMULATION] else config.batch_size
        gradient_accumulation_steps_scale = (
            1 if config.loss_scaler in [LossScaler.NONE, LossScaler.BATCH] else config.gradient_accumulation_steps
        )

        if self.__coefficients is None and betas is not None:
            self.__coefficients = DiffusionScheduleCoefficients.from_betas(betas)
        self.__alphas_cumprod_fun = alphas_cumprod_fun

        if data["loss_type"] == "align_prop":
            losses = self.__align_prop_losses(batch, data, config, train_device)
        else:
            # TODO: don't disable masked loss functions when has_conditioning_image_input is true.
            #  This breaks if only the VAE is trained but the model was loaded from an inpainting checkpoint.
            if config.masked_training and not config.model_type.has_conditioning_image_input():
                losses = self.__masked_losses(batch, data, config)
            else:
                losses = self.__unmasked_losses(batch, data, config)

        # Scale losses by batch size and/or gradient accumulation steps (if enabled)
        losses = losses * batch_size_scale * gradient_accumulation_steps_scale
        losses *= loss_weight.to(device=losses.device, dtype=losses.dtype)

        # Apply timestep-based loss weighting.
        if "timestep" in data and data["loss_type"] != "align_prop":
            v_pred = data.get("prediction_type", "") == "v_prediction"
            match config.loss_weight_fn:
                case LossWeight.MIN_SNR_GAMMA:
                    # Repurposed: this branch dispatches to the custom Sangoi modifier
                    # (see the note above __sangoi_loss_modifier).
                    losses *= self.__sangoi_loss_modifier(
                        data["timestep"], data["predicted"], data["target"], config.loss_weight_strength, losses.device
                    )
                case LossWeight.DEBIASED_ESTIMATION:
                    losses *= self.__debiased_estimation_weight(data["timestep"], v_pred, losses.device)
                case LossWeight.P2:
                    losses *= self.__p2_loss_weight(data["timestep"], config.loss_weight_strength, v_pred, losses.device)

        return losses

    def _flow_matching_losses(
            self,
            batch: dict,
            data: dict,
            config: TrainConfig,
            train_device: torch.device,
            sigmas: Tensor | None = None,
    ) -> Tensor:
        loss_weight = batch["loss_weight"]
        batch_size_scale = 1 if config.loss_scaler in [LossScaler.NONE, LossScaler.GRADIENT_ACCUMULATION] else config.batch_size
        gradient_accumulation_steps_scale = (
            1 if config.loss_scaler in [LossScaler.NONE, LossScaler.BATCH] else config.gradient_accumulation_steps
        )

        if data["loss_type"] == "align_prop":
            losses = self.__align_prop_losses(batch, data, config, train_device)
        else:
            # TODO: don't disable masked loss functions when has_conditioning_image_input is true.
            #  This breaks if only the VAE is trained but the model was loaded from an inpainting checkpoint.
            if config.masked_training and not config.model_type.has_conditioning_image_input():
                losses = self.__masked_losses(batch, data, config)
            else:
                losses = self.__unmasked_losses(batch, data, config)

        # Scale losses by batch size and/or gradient accumulation steps (if enabled)
        losses = losses * batch_size_scale * gradient_accumulation_steps_scale
        losses *= loss_weight.to(device=losses.device, dtype=losses.dtype)

        return losses