9 changes: 5 additions & 4 deletions retnet/configuration_retnet.py
@@ -36,7 +36,7 @@ class RetNetConfig(PretrainedConfig):
use_ffn_rms_norm: bool = False
layernorm_eps: float = 1e-6
tie_word_embeddings: bool = False

def __init__(
self,
vocab_size: int = 50257,
@@ -60,14 +60,15 @@ def __init__(
layernorm_embedding: bool = False, # add layernorm to embedding
no_scale_embedding: bool = True, # if True, dont scale embeddings
recurrent_chunk_size: int = 512,
- use_glu: bool = True, # use GLU instead of FFN
+ use_glu: bool = False, # use GLU instead of FFN
z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4
use_lm_decay: bool = False,
deepnorm: bool = False,
subln: bool = True,
use_ffn_rms_norm: bool = False, # use RMSNorm instead of LayerNorm in FFN
layernorm_eps: float = 1e-6,
tie_word_embeddings: bool = False,
+ use_flash_retention: bool = False,
**kwargs):
self.vocab_size = vocab_size
self.initializer_range = initializer_range
@@ -96,7 +97,7 @@ def __init__(
# Blockwise
self.recurrent_chunk_size = recurrent_chunk_size
self.forward_impl = forward_impl

+ self.use_flash_retention = use_flash_retention
if self.deepnorm:
self.decoder_normalize_before = False
self.subln = False
@@ -114,4 +115,4 @@ def __init__(
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
- self.__dict__[hp] = getattr(args, hp, None)
+ self.__dict__[hp] = getattr(args, hp, None)
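For reference, a minimal sketch of how the updated config might be used after this change. The import path follows the file touched by this diff (retnet/configuration_retnet.py); the choice to enable use_flash_retention and the explicit vocab_size are illustrative assumptions, not part of the PR itself.

from retnet.configuration_retnet import RetNetConfig

# use_glu now defaults to False (plain FFN instead of GLU), and the new
# use_flash_retention argument is stored on the config by __init__.
config = RetNetConfig(
    vocab_size=50257,           # default from the signature, shown explicitly
    use_flash_retention=True,   # illustrative: opt in to the new flag
)
assert config.use_flash_retention  # set via self.use_flash_retention in __init__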