
Commit 6ac7427

update EMA with newest
1 parent 6c10d9b commit 6ac7427

File tree: 1 file changed (+47, -93 lines)

taming/modules/vqvae/quantize.py

Lines changed: 47 additions & 93 deletions
@@ -328,21 +328,44 @@ def get_codebook_entry(self, indices, shape):
 
         return z_q
 
+class EmbeddingEMA(nn.Module):
+    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+        super().__init__()
+        self.decay = decay
+        self.eps = eps
+        weight = torch.randn(num_tokens, codebook_dim)
+        self.weight = nn.Parameter(weight, requires_grad = False)
+        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
+        self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
+        self.update = True
+
+    def forward(self, embed_id):
+        return F.embedding(embed_id, self.weight)
+
+    def cluster_size_ema_update(self, new_cluster_size):
+        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+    def embed_avg_ema_update(self, new_embed_avg):
+        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+    def weight_update(self, num_tokens):
+        n = self.cluster_size.sum()
+        smoothed_cluster_size = (
+            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+        )
+        #normalize embedding average with smoothed cluster size
+        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+        self.weight.data.copy_(embed_normalized)
 
 
 class EMAVectorQuantizer(nn.Module):
     def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                  remap=None, unknown_index="random"):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.n_embed = n_embed
-        self.decay = decay
-        self.eps = eps
+        self.codebook_dim = codebook_dim
+        self.num_tokens = num_tokens
         self.beta = beta
-        self.embedding = nn.Embedding(self.n_embed, self.embedding_dim)
-        self.embedding.weight.requires_grad = False
-        self.cluster_size = nn.Parameter(torch.zeros(n_embed),requires_grad=False)
-        self.embed_avg = nn.Parameter(torch.randn(self.n_embed, self.embedding_dim),requires_grad=False)
+        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
 
         self.remap = remap
         if self.remap is not None:
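Note on the hunk above: the new EmbeddingEMA module keeps the codebook state (weight, cluster_size, embed_avg) as non-trainable parameters and exposes the three EMA steps as methods, plus an update flag that the training branch further down checks before applying them. A minimal sketch of driving those methods by hand, assuming the class is imported from the file changed in this commit; the toy sizes and random code assignments are illustrative only, not part of the commit:

import torch
import torch.nn.functional as F

from taming.modules.vqvae.quantize import EmbeddingEMA

num_tokens, codebook_dim, batch = 8, 4, 32          # toy sizes, purely illustrative
ema = EmbeddingEMA(num_tokens, codebook_dim, decay=0.99, eps=1e-5)

# pretend a batch of flattened encoder outputs was assigned to random codes
z_flattened = torch.randn(batch, codebook_dim)
encoding_indices = torch.randint(0, num_tokens, (batch,))
encodings = F.one_hot(encoding_indices, num_tokens).type(z_flattened.dtype)

# the three steps, in the order the refactored forward() calls them
ema.cluster_size_ema_update(encodings.sum(0))                        # N_i <- decay*N_i + (1-decay)*n_i
ema.embed_avg_ema_update(encodings.transpose(0, 1) @ z_flattened)    # m_i <- decay*m_i + (1-decay)*sum of assigned z
ema.weight_update(num_tokens)                                        # e_i <- m_i / N_i, with N_i Laplace-smoothed by eps

print(ema.weight.shape)   # torch.Size([8, 4])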
@@ -384,37 +407,31 @@ def unmap_to_all(self, inds):
     def forward(self, z):
         # reshape z -> (batch, height, width, channel) and flatten
         #z, 'b c h w -> b h w c'
-        z = z.permute(0, 2, 3, 1).contiguous()
-        z_flattened = z.view(-1, self.embedding_dim)
+        z = rearrange(z, 'b c h w -> b h w c')
+        z_flattened = z.reshape(-1, self.codebook_dim)
+
         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
 
-        d = torch.sum(z_flattened.pow(2), dim=1, keepdim=True) + \
-            torch.sum(self.embedding.weight.pow(2), dim=1) - 2 * \
-            torch.einsum('bd,dn->bn', z_flattened, self.embedding.weight.permute(1,0)) # 'n d -> d n'
 
         encoding_indices = torch.argmin(d, dim=1)
+
         z_q = self.embedding(encoding_indices).view(z.shape)
-        encodings = F.one_hot(encoding_indices, self.n_embed).type(z.dtype)
+        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
         avg_probs = torch.mean(encodings, dim=0)
         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
 
-        if self.training:
-            encodings_sum = encodings.sum(0)
+        if self.training and self.embedding.update:
             #EMA cluster size
-            self.cluster_size.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
-
-            embed_sum = torch.matmul(encodings.t(), z_flattened)
+            encodings_sum = encodings.sum(0)
+            self.embedding.cluster_size_ema_update(encodings_sum)
             #EMA embedding average
-            self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
-
-            #cluster size Laplace smoothing
-            n = self.cluster_size.sum()
-            cluster_size = (
-                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
-            )
-            #normalize embedding average with smoothed cluster size
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
-            self.embedding.weight.data.copy_(embed_normalized.data)
+            embed_sum = encodings.transpose(0,1) @ z_flattened
+            self.embedding.embed_avg_ema_update(embed_sum)
+            #normalize embed_avg and update weight
+            self.embedding.weight_update(self.num_tokens)
 
         # compute loss for embedding
         loss = self.beta * F.mse_loss(z_q.detach(), z)
@@ -424,68 +441,5 @@ def forward(self, z):
 
         # reshape back to match original input shape
         #z_q, 'b h w c -> b c h w'
-        z_q = z_q.permute(0, 3, 1, 2).contiguous()
-        return z_q, loss, (perplexity, encodings, encoding_indices)
-
-
-
-#Original Sonnet version of EMAVectorQuantizer
-class EmbeddingEMA(nn.Module):
-    def __init__(self, n_embed, embedding_dim):
-        super().__init__()
-        weight = torch.randn(embedding_dim, n_embed)
-        self.register_buffer("weight", weight)
-        self.register_buffer("cluster_size", torch.zeros(n_embed))
-        self.register_buffer("embed_avg", weight.clone())
-
-    def forward(self, embed_id):
-        return F.embedding(embed_id, self.weight.transpose(0, 1))
-
-
-class SonnetEMAVectorQuantizer(nn.Module):
-    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
-                 remap=None, unknown_index="random"):
-        super().__init__()
-        self.embedding_dim = embedding_dim
-        self.n_embed = n_embed
-        self.decay = decay
-        self.eps = eps
-        self.beta = beta
-        self.embedding = EmbeddingEMA(n_embed,embedding_dim)
-
-    def forward(self, z):
-        z = z.permute(0, 2, 3, 1).contiguous()
-        z_flattened = z.reshape(-1, self.embedding_dim)
-        d = (
-            z_flattened.pow(2).sum(1, keepdim=True)
-            - 2 * z_flattened @ self.embedding.weight
-            + self.embedding.weight.pow(2).sum(0, keepdim=True)
-        )
-        _, encoding_indices = (-d).max(1)
-        encodings = F.one_hot(encoding_indices, self.n_embed).type(z_flattened.dtype)
-        encoding_indices = encoding_indices.view(*z.shape[:-1])
-        z_q = self.embedding(encoding_indices)
-        avg_probs = torch.mean(encodings, dim=0)
-        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
-
-        if self.training:
-            encodings_sum = encodings.sum(0)
-            embed_sum = z_flattened.transpose(0, 1) @ encodings
-            #EMA cluster size
-            self.embedding.cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
-            #EMA embedding average
-            self.embedding.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
-
-            #cluster size Laplace smoothing
-            n = self.embedding.cluster_size.sum()
-            cluster_size = (
-                (self.embedding.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
-            )
-            #normalize embedding average with smoothed cluster size
-            embed_normalized = self.embedding.embed_avg / cluster_size.unsqueeze(0)
-            self.embedding.weight.data.copy_(embed_normalized)
-
-        loss = self.beta * (z_q.detach() - z).pow(2).mean()
-        z_q = z + (z_q - z).detach()
-        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+        z_q = rearrange(z_q, 'b h w c -> b c h w')
         return z_q, loss, (perplexity, encodings, encoding_indices)
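Taken together, the refactored forward() reduces to nearest-codebook lookup, the three EMA calls on self.embedding during training, a commitment loss, and the straight-through estimator. Note that the committed __init__ still declares n_embed and embedding_dim in its signature while assigning codebook_dim and num_tokens in the body, so the sketch below works directly against EmbeddingEMA instead of constructing EMAVectorQuantizer. It mirrors the forward() above with perplexity bookkeeping omitted; the ema_quantize helper name, the toy sizes, and the beta default are illustrative only, not part of the commit:

import torch
import torch.nn.functional as F
from einops import rearrange

from taming.modules.vqvae.quantize import EmbeddingEMA

def ema_quantize(z, codebook, beta=0.25, training=True):
    # z: (batch, channel, height, width); codebook: EmbeddingEMA(num_tokens, codebook_dim)
    num_tokens, codebook_dim = codebook.weight.shape
    z = rearrange(z, 'b c h w -> b h w c')
    z_flattened = z.reshape(-1, codebook_dim)

    # squared distances to every code: (z - e)^2 = z^2 + e^2 - 2 z.e
    d = z_flattened.pow(2).sum(dim=1, keepdim=True) \
        + codebook.weight.pow(2).sum(dim=1) \
        - 2 * torch.einsum('bd,nd->bn', z_flattened, codebook.weight)

    encoding_indices = torch.argmin(d, dim=1)
    z_q = codebook(encoding_indices).view(z.shape)
    encodings = F.one_hot(encoding_indices, num_tokens).type(z.dtype)

    if training and codebook.update:
        codebook.cluster_size_ema_update(encodings.sum(0))
        codebook.embed_avg_ema_update(encodings.transpose(0, 1) @ z_flattened)
        codebook.weight_update(num_tokens)

    loss = beta * F.mse_loss(z_q.detach(), z)   # commitment loss only; the codebook moves via EMA
    z_q = z + (z_q - z).detach()                # straight-through estimator
    return rearrange(z_q, 'b h w c -> b c h w'), loss

# toy usage
codebook = EmbeddingEMA(num_tokens=1024, codebook_dim=64)
z_q, loss = ema_quantize(torch.randn(2, 64, 16, 16), codebook)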
