diff --git a/open_lm/model.py b/open_lm/model.py
index 06f260a1..d0750af0 100644
--- a/open_lm/model.py
+++ b/open_lm/model.py
@@ -19,6 +19,11 @@
 from open_lm.positional_embedding.rotary import RotaryWithCast
 from open_lm.positional_embedding.llama_rotary import LLaMARotaryWithCast
 
+try:  # optional import
+    from mamba_ssm import MambaLMHeadModel
+except ImportError:
+    MambaLMHeadModel = None
+
 # from openclip
 _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
 _MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs
@@ -305,20 +310,57 @@ def create_params(args):
     # These args are managed separately by the argparser
     # If a parameter is in the model config, regardless of the args, we use the config parameters
     # If a parameter is not in the model config, we use the args parameter
-    return Params(
-        dim=cfg["hidden_dim"],
-        n_layers=cfg["n_layers"],
-        n_heads=cfg["n_heads"],
-        seq_len=cfg["seq_len"],
-        vocab_size=cfg["vocab_size"],
-        post_embed_norm=cfg["post_embed_norm"],
-        weight_tying=cfg["weight_tying"],
-        norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)),
-        apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
-        positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
-        ffn_type=cfg.get("ffn_type", args.ffn_type),
-    )
+
+    if "mamba" in args.model:
+        return {
+            "d_model": cfg["d_model"],
+            "n_layer": cfg["n_layer"],
+            "vocab_size": cfg["vocab_size"],
+            "seq_len": cfg["seq_len"],
+        }
+    else:
+        return Params(
+            dim=cfg["hidden_dim"],
+            n_layers=cfg["n_layers"],
+            n_heads=cfg["n_heads"],
+            seq_len=cfg["seq_len"],
+            vocab_size=cfg["vocab_size"],
+            post_embed_norm=cfg["post_embed_norm"],
+            weight_tying=cfg["weight_tying"],
+            norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)),
+            apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
+            positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
+            ffn_type=cfg.get("ffn_type", args.ffn_type),
+        )
+
+
+class Mamba(nn.Module):
+    # Experimental architecture, please "pip install mamba-ssm"
+    # https://arxiv.org/abs/2312.00752
+    def __init__(self, params):
+        if MambaLMHeadModel is None:
+            raise ImportError(
+                "MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'."
+            )
+
+        super().__init__()
+        self.seq_len = params.pop("seq_len")
+        self.vocab_size = params["vocab_size"]
+
+        self.model = MambaLMHeadModel(**params)
+
+    def reset_parameters(self):
+        return
+
+    def forward(self, x):
+        out = self.model(x).logits
+        return out, None
 
 
 def create_model(args):
-    return Transformer(create_params(args))
+    if "mamba" in args.model:
+        model = Mamba(create_params(args))
+        return model
+    else:
+        model = Transformer(create_params(args))
+        return model
diff --git a/open_lm/model_configs/mamba_130m.json b/open_lm/model_configs/mamba_130m.json
new file mode 100644
index 00000000..172c7e5b
--- /dev/null
+++ b/open_lm/model_configs/mamba_130m.json
@@ -0,0 +1,6 @@
+{
+    "d_model": 768,
+    "n_layer": 12,
+    "vocab_size": 50432,
+    "seq_len": 2048
+}
diff --git a/open_lm/model_configs/mamba_1b.json b/open_lm/model_configs/mamba_1b.json
new file mode 100644
index 00000000..6af97f5c
--- /dev/null
+++ b/open_lm/model_configs/mamba_1b.json
@@ -0,0 +1,6 @@
+{
+    "d_model": 2048,
+    "n_layer": 36,
+    "vocab_size": 50432,
+    "seq_len": 2048
+}
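
A minimal usage sketch (not part of the diff) of how the new Mamba path would be exercised through create_model. It assumes mamba-ssm is installed and a CUDA device is available (the mamba_ssm kernels generally require one), and that the argparse.Namespace below carries enough fields for the rest of create_params; depending on how the surrounding open_lm code resolves configs, additional attributes may be needed.

    # Sketch: drive the "mamba" branch of create_model added by this diff.
    import argparse
    import torch

    from open_lm.model import create_model

    # Only the `model` field is needed by the Mamba branch shown in the diff;
    # it selects open_lm/model_configs/mamba_130m.json.
    args = argparse.Namespace(model="mamba_130m")

    model = create_model(args).cuda()  # "mamba" in args.model -> Mamba wrapper
    tokens = torch.randint(0, 50432, (2, 16), device="cuda")
    logits, _ = model(tokens)          # Mamba.forward returns (logits, None)
    print(logits.shape)                # expected: (2, 16, 50432)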