@@ -328,21 +328,44 @@ def get_codebook_entry(self, indices, shape):
         return z_q
 
+class EmbeddingEMA(nn.Module):
+    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+        super().__init__()
+        self.decay = decay
+        self.eps = eps
+        weight = torch.randn(num_tokens, codebook_dim)
+        self.weight = nn.Parameter(weight, requires_grad=False)
+        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+        self.update = True
+
+    def forward(self, embed_id):
+        return F.embedding(embed_id, self.weight)
+
+    def cluster_size_ema_update(self, new_cluster_size):
+        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+    def embed_avg_ema_update(self, new_embed_avg):
+        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+    def weight_update(self, num_tokens):
+        n = self.cluster_size.sum()
+        smoothed_cluster_size = (
+            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+        )
+        #normalize embedding average with smoothed cluster size
+        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+        self.weight.data.copy_(embed_normalized)
 
 
 class EMAVectorQuantizer(nn.Module):
     def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                  remap=None, unknown_index="random"):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.n_embed = n_embed
-        self.decay = decay
-        self.eps = eps
+        self.codebook_dim = embedding_dim
+        self.num_tokens = n_embed
         self.beta = beta
-        self.embedding = nn.Embedding(self.n_embed, self.embedding_dim)
-        self.embedding.weight.requires_grad = False
-        self.cluster_size = nn.Parameter(torch.zeros(n_embed), requires_grad=False)
-        self.embed_avg = nn.Parameter(torch.randn(self.n_embed, self.embedding_dim), requires_grad=False)
+        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
 
         self.remap = remap
         if self.remap is not None:
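Aside, in my own notation rather than anything stated in the patch: the three update methods added above implement the usual EMA codebook update with Laplace smoothing. Writing N_i for cluster_size[i], m_i for embed_avg[i], e_i for weight[i], gamma for decay, epsilon for eps, K for num_tokens, and n_i for the number of batch vectors assigned to code i, one forward pass in training mode performs

    N_i \leftarrow \gamma N_i + (1-\gamma)\, n_i
    m_i \leftarrow \gamma m_i + (1-\gamma) \sum_{b:\, \mathrm{enc}_b = i} z_b
    e_i = \frac{m_i}{\tilde N_i}, \qquad \tilde N_i = \frac{N_i + \epsilon}{n + K\epsilon}\, n, \qquad n = \sum_j N_j

cluster_size_ema_update and embed_avg_ema_update correspond to the first two lines, and weight_update to the smoothed normalization in the third.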
@@ -384,37 +407,31 @@ def unmap_to_all(self, inds):
     def forward(self, z):
         # reshape z -> (batch, height, width, channel) and flatten
         #z, 'b c h w -> b h w c'
-        z = z.permute(0, 2, 3, 1).contiguous()
-        z_flattened = z.view(-1, self.embedding_dim)
+        z = rearrange(z, 'b c h w -> b h w c')
+        z_flattened = z.reshape(-1, self.codebook_dim)
+
         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'
 
-        d = torch.sum(z_flattened.pow(2), dim=1, keepdim=True) + \
-            torch.sum(self.embedding.weight.pow(2), dim=1) - 2 * \
-            torch.einsum('bd,dn->bn', z_flattened, self.embedding.weight.permute(1,0))  # 'n d -> d n'
 
         encoding_indices = torch.argmin(d, dim=1)
+
         z_q = self.embedding(encoding_indices).view(z.shape)
-        encodings = F.one_hot(encoding_indices, self.n_embed).type(z.dtype)
+        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
         avg_probs = torch.mean(encodings, dim=0)
         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
 
-        if self.training:
-            encodings_sum = encodings.sum(0)
+        if self.training and self.embedding.update:
             #EMA cluster size
-            self.cluster_size.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
-
-            embed_sum = torch.matmul(encodings.t(), z_flattened)
+            encodings_sum = encodings.sum(0)
+            self.embedding.cluster_size_ema_update(encodings_sum)
             #EMA embedding average
-            self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
-
-            #cluster size Laplace smoothing
-            n = self.cluster_size.sum()
-            cluster_size = (
-                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
-            )
-            #normalize embedding average with smoothed cluster size
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
-            self.embedding.weight.data.copy_(embed_normalized.data)
+            embed_sum = encodings.transpose(0,1) @ z_flattened
+            self.embedding.embed_avg_ema_update(embed_sum)
+            #normalize embed_avg and update weight
+            self.embedding.weight_update(self.num_tokens)
 
         # compute loss for embedding
         loss = self.beta * F.mse_loss(z_q.detach(), z)
@@ -424,68 +441,5 @@ def forward(self, z):
 
         # reshape back to match original input shape
         #z_q, 'b h w c -> b c h w'
-        z_q = z_q.permute(0, 3, 1, 2).contiguous()
-        return z_q, loss, (perplexity, encodings, encoding_indices)
-
-
-
-#Original Sonnet version of EMAVectorQuantizer
-class EmbeddingEMA(nn.Module):
-    def __init__(self, n_embed, embedding_dim):
-        super().__init__()
-        weight = torch.randn(embedding_dim, n_embed)
-        self.register_buffer("weight", weight)
-        self.register_buffer("cluster_size", torch.zeros(n_embed))
-        self.register_buffer("embed_avg", weight.clone())
-
-    def forward(self, embed_id):
-        return F.embedding(embed_id, self.weight.transpose(0, 1))
-
-
-class SonnetEMAVectorQuantizer(nn.Module):
-    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
-                 remap=None, unknown_index="random"):
-        super().__init__()
-        self.embedding_dim = embedding_dim
-        self.n_embed = n_embed
-        self.decay = decay
-        self.eps = eps
-        self.beta = beta
-        self.embedding = EmbeddingEMA(n_embed, embedding_dim)
-
-    def forward(self, z):
-        z = z.permute(0, 2, 3, 1).contiguous()
-        z_flattened = z.reshape(-1, self.embedding_dim)
-        d = (
-            z_flattened.pow(2).sum(1, keepdim=True)
-            - 2 * z_flattened @ self.embedding.weight
-            + self.embedding.weight.pow(2).sum(0, keepdim=True)
-        )
-        _, encoding_indices = (-d).max(1)
-        encodings = F.one_hot(encoding_indices, self.n_embed).type(z_flattened.dtype)
-        encoding_indices = encoding_indices.view(*z.shape[:-1])
-        z_q = self.embedding(encoding_indices)
-        avg_probs = torch.mean(encodings, dim=0)
-        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
-
-        if self.training:
-            encodings_sum = encodings.sum(0)
-            embed_sum = z_flattened.transpose(0, 1) @ encodings
-            #EMA cluster size
-            self.embedding.cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
-            #EMA embedding average
-            self.embedding.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
-
-            #cluster size Laplace smoothing
-            n = self.embedding.cluster_size.sum()
-            cluster_size = (
-                (self.embedding.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
-            )
-            #normalize embedding average with smoothed cluster size
-            embed_normalized = self.embedding.embed_avg / cluster_size.unsqueeze(0)
-            self.embedding.weight.data.copy_(embed_normalized)
-
-        loss = self.beta * (z_q.detach() - z).pow(2).mean()
-        z_q = z + (z_q - z).detach()
-        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+        z_q = rearrange(z_q, 'b h w c -> b c h w')
         return z_q, loss, (perplexity, encodings, encoding_indices)
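For orientation, a minimal usage sketch of the quantizer as it reads after this patch. It assumes EMAVectorQuantizer is importable from the edited module and that the module imports rearrange from einops, as the added lines require; the sizes below are hypothetical and chosen only for illustration.

import torch

quantizer = EMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=0.25)  # hypothetical sizes
quantizer.train()                                  # EMA codebook updates only run in training mode

z = torch.randn(2, 256, 16, 16)                    # encoder features, 'b c h w'
z_q, loss, (perplexity, encodings, indices) = quantizer(z)

assert z_q.shape == z.shape                        # output is rearranged back to 'b c h w'
print(loss.item(), perplexity.item())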