model.py: Mamba SSM (#137)
* Mamba support

* clean up

* remove changes

* black fix

* wraps

* update comment

* raise import error

* black fix

---------

Co-authored-by: Maciej Kilian <[email protected]>
Co-authored-by: Maciej Kilian <[email protected]>
3 people authored Dec 9, 2023
1 parent 44e1f74 commit 158124d
Showing 3 changed files with 68 additions and 14 deletions.
70 changes: 56 additions & 14 deletions open_lm/model.py
@@ -19,6 +19,11 @@
from open_lm.positional_embedding.rotary import RotaryWithCast
from open_lm.positional_embedding.llama_rotary import LLaMARotaryWithCast

try: # optional import
    from mamba_ssm import MambaLMHeadModel
except ImportError:
    MambaLMHeadModel = None

# from openclip
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
@@ -305,20 +310,57 @@ def create_params(args):
    # These args are managed separately by the argparser
    # If a parameter is in the model config, regardless of the args, we use the config parameters
    # If a parameter is not in the model config, we use the args parameter
    return Params(
        dim=cfg["hidden_dim"],
        n_layers=cfg["n_layers"],
        n_heads=cfg["n_heads"],
        seq_len=cfg["seq_len"],
        vocab_size=cfg["vocab_size"],
        post_embed_norm=cfg["post_embed_norm"],
        weight_tying=cfg["weight_tying"],
        norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)),
        apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
        positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
        ffn_type=cfg.get("ffn_type", args.ffn_type),
    )

    if "mamba" in args.model:
        return {
            "d_model": cfg["d_model"],
            "n_layer": cfg["n_layer"],
            "vocab_size": cfg["vocab_size"],
            "seq_len": cfg["seq_len"],
        }
    else:
        return Params(
            dim=cfg["hidden_dim"],
            n_layers=cfg["n_layers"],
            n_heads=cfg["n_heads"],
            seq_len=cfg["seq_len"],
            vocab_size=cfg["vocab_size"],
            post_embed_norm=cfg["post_embed_norm"],
            weight_tying=cfg["weight_tying"],
            norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)),
            apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
            positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
            ffn_type=cfg.get("ffn_type", args.ffn_type),
        )


class Mamba(nn.Module):
    # Experimental architecture, please "pip install mamba-ssm"
    # https://arxiv.org/abs/2312.00752
    def __init__(self, params):
        if MambaLMHeadModel is None:
            raise ImportError(
                "MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'."
            )

        super().__init__()
        self.seq_len = params.pop("seq_len")
        self.vocab_size = params["vocab_size"]

        self.model = MambaLMHeadModel(**params)

    def reset_parameters(self):
        return

    def forward(self, x):
        out = self.model(x).logits
        return out, None


def create_model(args):
    return Transformer(create_params(args))
    if "mamba" in args.model:
        model = Mamba(create_params(args))
        return model
    else:
        model = Transformer(create_params(args))
        return model
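
The new Mamba wrapper keeps the interface minimal: __init__ pops "seq_len" from the params dict, records vocab_size, and forwards the remaining keys to mamba_ssm's MambaLMHeadModel, while forward returns a (logits, None) tuple. Below is a minimal usage sketch, not part of the commit, assuming a mamba-ssm release whose MambaLMHeadModel accepts d_model/n_layer/vocab_size keyword arguments (as this wrapper expects) and a CUDA device for its selective-scan kernels:

import torch

from open_lm.model import Mamba  # wrapper added in this commit

# Parameters matching the new mamba_130m.json config.
params = {"d_model": 768, "n_layer": 12, "vocab_size": 50432, "seq_len": 2048}

model = Mamba(params).cuda()  # __init__ pops "seq_len" and forwards the rest to MambaLMHeadModel

tokens = torch.randint(0, model.vocab_size, (2, 128), device="cuda")
logits, _ = model(tokens)  # forward returns (logits, None)
print(logits.shape)  # torch.Size([2, 128, 50432]) for this config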
6 changes: 6 additions & 0 deletions open_lm/model_configs/mamba_130m.json
@@ -0,0 +1,6 @@
{
    "d_model": 768,
    "n_layer": 12,
    "vocab_size": 50432,
    "seq_len": 2048
}
6 changes: 6 additions & 0 deletions open_lm/model_configs/mamba_1b.json
@@ -0,0 +1,6 @@
{
    "d_model": 2048,
    "n_layer": 36,
    "vocab_size": 50432,
    "seq_len": 2048
}
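
These two configs carry exactly the four keys the mamba branch of create_params reads (d_model, n_layer, vocab_size, seq_len), so adding another size is just a new JSON file whose name contains "mamba". A hypothetical end-to-end sketch, assuming the unchanged config-lookup code above the shown hunk needs nothing from args beyond args.model:

from argparse import Namespace

from open_lm.model import create_model

args = Namespace(model="mamba_130m")  # "mamba" in the name routes create_model to the Mamba wrapper
model = create_model(args)  # Mamba wrapping MambaLMHeadModel(d_model=768, n_layer=12, vocab_size=50432)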
