Skip to content

Commit e19918f

Browse files
authored
Update the initialisation of init_state, state_pos_ids and learned_ema_beta to match the original repo
1 parent b22f552 commit e19918f

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,10 @@ def __init__(
330330
self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)
331331

332332
self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
333+
torch.nn.init.normal_(self.init_state, 0, .1)
333334
self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))
335+
# NOTE: the state position id embeddings are drawn from N(0,1) since they are added after a layer norm
336+
torch.nn.init.normal_(self.state_pos_ids, 0, 1)
334337

335338
self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)
336339

@@ -343,6 +346,7 @@ def __init__(
343346

344347
self.state_out_to_gate = nn.Linear(dim, dim)
345348
self.learned_ema_beta = nn.Parameter(torch.randn(dim))
349+
torch.nn.init.normal_(self.learned_ema_beta, 0, .1)
346350

347351
# since each read should be followed by a write, just store cache in the container
348352

0 commit comments

Comments (0)