Mixture of Experts #115

Merged (63 commits, Dec 18, 2023)
Changes from 53 commits.

Commits (63):
8fcc9b9  update (Nov 21, 2023)
7093f83  works on one gpu (Nov 25, 2023)
b6ba69f  added moe params (Nov 26, 2023)
8b4c230  works on multiple gpus (Nov 27, 2023)
45c1849  update (Nov 27, 2023)
66359d9  eval works (Nov 27, 2023)
7466d60  update (Nov 27, 2023)
440f9f7  update (Nov 27, 2023)
5c37fdb  update (Nov 27, 2023)
429a780  update (Nov 27, 2023)
930174f  update (Nov 27, 2023)
a302c66  removed experiment dir (Nov 27, 2023)
6483ef2  removed experiments dir (Nov 27, 2023)
d0ea8dc  removed custom fsdp (Nov 27, 2023)
4eb3dd0  update (Nov 27, 2023)
a46a2fe  update (Nov 27, 2023)
f5421f6  added expert gradient norm (Nov 29, 2023)
e254615  update (Dec 12, 2023)
0912ba8  update (Dec 12, 2023)
0a4d0a3  Update model.py (kernelmachine, Dec 12, 2023)
53905bc  added load balancing loss (Dec 12, 2023)
a4d0efc  Merge branch 'moe' of github.com:mlfoundations/open_lm into moe (Dec 12, 2023)
52e2cfd  update (Dec 12, 2023)
ebacdbe  update (Dec 13, 2023)
f404b0d  update (Dec 13, 2023)
2efdf17  update (Dec 13, 2023)
88156ca  update (Dec 14, 2023)
e783e9e  update (Dec 14, 2023)
f027203  update (Dec 14, 2023)
5c70426  fixed merge conflicts (Dec 15, 2023)
a6be666  removed unnecessary dir (Dec 15, 2023)
39124e7  Remove the now ignored directory experiments (Dec 15, 2023)
6196cbc  update (Dec 15, 2023)
560250a  update (Dec 15, 2023)
c904356  update (Dec 15, 2023)
525083c  update (Dec 15, 2023)
25e6948  update (Dec 15, 2023)
4849f2a  update (Dec 15, 2023)
4c41b61  update (Dec 15, 2023)
5c32f64  black formatting (Dec 15, 2023)
78ac356  removed custom moe dir (Dec 15, 2023)
24cbb9b  update (Dec 15, 2023)
ec8cd0f  update (Dec 15, 2023)
427bc39  update (Dec 15, 2023)
64e512b  update (Dec 15, 2023)
cbfa187  update (Dec 15, 2023)
90648c9  update (Dec 15, 2023)
32cba9d  update (Dec 15, 2023)
da9d4a7  update (Dec 15, 2023)
5741fde  Merge branch 'main' into moe (kernelmachine, Dec 15, 2023)
a70609b  Merge branch 'main' of github.com:mlfoundations/open_lm into moe (Dec 15, 2023)
e722226  update (Dec 15, 2023)
5980bf8  Merge branch 'moe' of github.com:mlfoundations/open_lm into moe (Dec 15, 2023)
ce3c838  update (Dec 16, 2023)
8b7c77d  update (Dec 16, 2023)
8daf040  update (Dec 16, 2023)
dd552d8  update (Dec 16, 2023)
1c9ff6f  update (Dec 16, 2023)
9e2aa30  update (Dec 16, 2023)
5e78109  update (Dec 18, 2023)
7ea0385  update (Dec 18, 2023)
7b68747  update (Dec 18, 2023)
a50fed2  Merge branch 'main' into moe (kernelmachine, Dec 18, 2023)
1 change: 1 addition & 0 deletions .gitignore
@@ -140,3 +140,4 @@ out*
tests/assets/*
.vscode/
checkpoints/
experiments/
72 changes: 72 additions & 0 deletions MOE.md
@@ -0,0 +1,72 @@
# Mixture of Experts Language Models

## Dependencies

```
pip install megablocks
pip install xformers==0.0.22.post4
```

## Train MoE

To train an MoE, add the `--moe-*` arguments to the training command:

```
torchrun --nproc-per-node 8 -m open_lm.main \
--train-num-samples 1000000000 \
--workers 2 \
--dataset-manifest "s3://laion-west/rpj_tokenized_upsampled_eleutherai/manifest.jsonl" "s3://laion-west/2T_no_rpj_tokenized_upsampled_25k_shards/manifest.jsonl" \
--train-data-mix-weights 0.725 0.275 \
--precision amp_bfloat16 \
--batch-size 8 \
--log-every-n-steps 20 \
--grad-clip-norm 1 \
--lr 6e-4 \
--warmup 200 \
--model aphid_neox \
--wd 0.01 \
--beta2 0.95 \
--epochs 4 \
--report-to wandb \
--wandb-project-name moe \
--name test_moe \
--logs /fsx/home-$USER/experiments/moe \
--resume latest \
--seed 124 \
--data-key 'json' \
--accum-freq 4 \
--model-norm gain_only_layer_norm \
--fsdp --fsdp-amp \
--lr-cooldown-end 1e-5 \
--no-skip-tokens \
--accurate-total-tokens \
--moe-freq 2 \
--moe-num-experts 8 \
--moe-top-k 2 \
--moe-capacity-factor 1.25 --moe-loss-weight 0.1
```

The command above adds an MoE FFN layer to every other Transformer block (every `--moe-freq`-th block, counting from layer 0). You can use an arbitrary number of experts; you are limited only by the total memory available across your GPUs.
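
For intuition, the sketch below (not part of open_lm) mirrors the layer-construction logic in `open_lm/model.py`: a block gets an MoE FFN when its layer index is a multiple of `moe_freq`.

```
# Illustrative sketch: which transformer blocks get an MoE FFN for a given --moe-freq.
# Mirrors the layer-construction loop in open_lm/model.py (layer_id % moe_freq == 0).
def ffn_types(n_layers, moe_freq, base_ffn="swiglu"):
    return [
        "moe" if moe_freq > 0 and layer_id % moe_freq == 0 else base_ffn
        for layer_id in range(n_layers)
    ]

print(ffn_types(n_layers=8, moe_freq=2))
# ['moe', 'swiglu', 'moe', 'swiglu', 'moe', 'swiglu', 'moe', 'swiglu']
```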


You can also pass `--moe-expert-model-parallelism`, which distributes experts across GPUs. If the number of GPUs exceeds the number of experts, the leftover factor (#GPUs / #experts) is applied as tensor parallelism on top of the expert parallelism. This mode is not yet eval-friendly, so I would not recommend using it for now.

You can evaluate the MoE in the same way as dense models:

```
torchrun --nproc-per-node 8 -m open_lm.main \
--val-data "pipe:aws s3 cp s3://laion-west/lmdata/validation_data_tokenized/open_lm//shard_00000000.tar -" \
--workers 6 \
--precision amp_bfloat16 \
--batch-size 8 \
--log-every-n-steps 1 \
--model open_lm_41m \
--fsdp --fsdp-amp \
--moe-num-experts 64 --moe-freq 2 \
--data-key json \
--train-num-samples 1000000000 \
--model-norm gain_only_layer_norm \
--name $RANDOM \
--resume /fsx/home-suching/experiments/mix_wo/test8086/checkpoints/epoch_1.pt \
--logs /fsx/home-$USER/experiments/eval
```
6 changes: 5 additions & 1 deletion open_lm/main.py
@@ -399,7 +399,11 @@ def main(args):
if args.hf_model is not None:
model = create_wrapped_hf_model(args)
else:
with torch.device("meta" if args.fsdp else args.device):
# with torch.device("meta" if args.fsdp else args.device):
# model = create_model(args)
# if not args.fsdp:
# model.reset_parameters()
with torch.device(args.device):
model = create_model(args)
if not args.fsdp:
model.reset_parameters()
54 changes: 52 additions & 2 deletions open_lm/model.py
@@ -1,5 +1,6 @@
import math
import json
import logging
import re
from copy import deepcopy
from pathlib import Path
@@ -19,6 +20,15 @@
from open_lm.positional_embedding.rotary import RotaryWithCast
from open_lm.positional_embedding.llama_rotary import LLaMARotaryWithCast


# from open_lm.moe.mixture_of_experts import MoE
try:
from megablocks.layers.moe import MoE
from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
logging.warning(f"Megablocks not installed. To train MoE, install with pip install megablocks.")
Contributor: Should we have some assert here to make sure they're not using MoE?

Collaborator (author): we can't check args during imports though, right?
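
One way to address the reviewer's point, sketched here as a hypothetical pattern rather than code from this PR, is to record whether the import succeeded and raise at model-construction time, once the MoE params are known:

```
# Hypothetical sketch (not in this PR): record import success at module load,
# then check it once the MoE params are available, since args cannot be
# inspected at import time.
try:
    from megablocks.layers.moe import MoE
    from megablocks.layers.arguments import Arguments as MoEArgs

    MEGABLOCKS_AVAILABLE = True
except ImportError:
    MEGABLOCKS_AVAILABLE = False


def _require_megablocks_if_moe(params):
    if params.moe_freq > 0 and not MEGABLOCKS_AVAILABLE:
        raise ImportError("moe_freq > 0 requires megablocks: pip install megablocks")
```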



try: # optional import
from mamba_ssm import MambaLMHeadModel
except ImportError:
@@ -77,6 +87,13 @@ class Params:
weight_tying: bool = False
norm_type: nn.Module = nn.LayerNorm
apply_qk_norm: bool = False
moe_loss_weight: float = 0.1
moe_capacity_factor: float = 1.25
moe_expert_model_parallelism: bool = False
moe_weight_parallelism: bool = False
moe_num_experts: int = 8
moe_top_k: int = 2
moe_freq: int = 0
positional_embedding_type: str = "rotary"
ffn_type: str = "swiglu"

@@ -237,10 +254,10 @@ def __init__(self, layer_id, args: Params):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim

self.head_dim = args.dim // args.n_heads
self.attention = CustomAttn(layer_id, args)
self._ffn_type = args.ffn_type

if args.ffn_type == "swiglu":
# this follows llama / lit llama -- go to multiple of 256
self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
@@ -251,6 +268,21 @@ def __init__(self, layer_id, args: Params):
self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
elif args.ffn_type == "moe":
moe_args = MoEArgs(
hidden_size=args.dim,
ffn_hidden_size=args.dim * 4,
moe_num_experts=args.moe_num_experts,
moe_weight_parallelism=args.moe_weight_parallelism,
moe_expert_model_parallelism=args.moe_expert_model_parallelism,
moe_top_k=args.moe_top_k,
moe_capacity_factor=args.moe_capacity_factor,
moe_loss_weight=args.moe_loss_weight,
device=torch.cuda.current_device(),
bf16=False,
fp16=False,
)
self.feed_forward = MoE(moe_args)

self.layer_id = layer_id
self.attention_norm = args.norm_type(
@@ -289,7 +321,11 @@ def forward(self, x, past_key_value=None, use_cache=False):
use_cache=use_cache,
)
h = x + h
out = h + self.feed_forward(self.ffn_norm(h))
if self._ffn_type == "moe":
ffn_out, _ = self.feed_forward(self.ffn_norm(h))
else:
ffn_out = self.feed_forward(self.ffn_norm(h))
out = h + ffn_out
return out, past_key_value
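
As the branch above shows, the megablocks MoE layer returns a tuple (the layer output plus auxiliary routing outputs), while the dense FFN returns a plain tensor. A minimal, hypothetical illustration of that dispatch (not code from this PR):

```
# Hypothetical illustration of the dispatch above: the MoE feed-forward returns
# a tuple, the dense feed-forward returns a tensor, and only the output is kept.
def apply_ffn(feed_forward, x, ffn_type):
    if ffn_type == "moe":
        out, _aux = feed_forward(x)
        return out
    return feed_forward(x)
```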


@@ -298,8 +334,10 @@ def __init__(self, params):
super().__init__()
# for convenience we often share param names with llama
self.params = params
self.dim = params.dim
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.moe_num_experts = params.moe_num_experts
self.seq_len = params.seq_len
self.post_embed_norm = (
params.norm_type(
@@ -314,7 +352,12 @@ def __init__(self, params):
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

self.layers = torch.nn.ModuleList()
ffn_type_ = params.ffn_type
for layer_id in range(params.n_layers):
if params.moe_freq > 0 and layer_id % params.moe_freq == 0:
params.ffn_type = "moe"
else:
params.ffn_type = ffn_type_
self.layers.append(Block(layer_id, params))

# get class for normalization layers
@@ -405,6 +448,13 @@ def create_params(args):
apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
ffn_type=cfg.get("ffn_type", args.ffn_type),
moe_num_experts=cfg.get("moe_num_experts", args.moe_num_experts),
moe_loss_weight=cfg.get("moe_loss_weight", args.moe_loss_weight),
moe_expert_model_parallelism=cfg.get("moe_expert_model_parallelism", args.moe_expert_model_parallelism),
moe_weight_parallelism=cfg.get("moe_weight_parallelism", args.moe_weight_parallelism),
moe_capacity_factor=cfg.get("moe_capacity_factor", args.moe_capacity_factor),
moe_freq=cfg.get("moe_freq", args.moe_freq),
moe_top_k=cfg.get("moe_top_k", args.moe_top_k),
)
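
Note that the `cfg.get(key, args.<key>)` pattern above means a value set in the model's JSON config takes precedence over the corresponding CLI flag; a minimal illustration with hypothetical values:

```
# Minimal illustration of the cfg.get(...) precedence used in create_params:
# a key present in the model JSON config wins over the CLI value.
cfg = {"moe_num_experts": 64}   # hypothetical model config entry
cli_moe_num_experts = 8         # hypothetical value from --moe-num-experts
print(cfg.get("moe_num_experts", cli_moe_num_experts))  # 64
```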


44 changes: 44 additions & 0 deletions open_lm/params.py
@@ -56,6 +56,50 @@ def add_model_args(parser):
default="rotary",
help="Type of positional embedding to use. This might be overridden by the model config.",
)
parser.add_argument(
"--moe-freq",
type=int,
default=0,
help="if set > 0, we will add MoE layer to every moe_freq layer.",
)
parser.add_argument(
"--moe-num-experts",
type=int,
default=None,
help="Number of experts for MoE",
)

parser.add_argument(
"--moe-weight-parallelism",
action="store_true",
help="Add weight parallelism to MoE",
)

parser.add_argument(
"--moe-expert-model-parallelism",
action="store_true",
help="Add expert model parallelism to MoE",
)

parser.add_argument(
"--moe-capacity-factor",
type=float,
default=1.25,
help="MoE capacity factor",
)

parser.add_argument(
"--moe-loss-weight",
type=float,
default=0.1,
help="MoE loss weight",
)
parser.add_argument(
"--moe-top-k",
type=int,
default=2,
help="MoE top k experts",
)
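
For reference, a hypothetical standalone snippet showing how these flags surface after parsing (it assumes `add_model_args` is imported from `open_lm.params` and that the remaining model arguments keep their defaults):

```
# Hypothetical standalone check of the MoE flags added above.
import argparse

from open_lm.params import add_model_args  # assumed import path

parser = argparse.ArgumentParser()
add_model_args(parser)
args = parser.parse_args(
    "--moe-freq 2 --moe-num-experts 8 --moe-top-k 2 "
    "--moe-capacity-factor 1.25 --moe-loss-weight 0.1".split()
)
print(args.moe_freq, args.moe_num_experts, args.moe_top_k)  # 2 8 2
```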


def check_replacement_type(replacement, original):
46 changes: 43 additions & 3 deletions open_lm/train.py
@@ -11,6 +11,13 @@
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


try:
from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
logging.warning(f"Megablocks not installed. To train MoE, install with pip install megablocks.")
Contributor: Again maybe some assert would be good?


try:
import wandb
except ImportError:
@@ -155,6 +162,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

losses_m = AverageMeter()
load_balancing_losses_m = AverageMeter()
batch_time_m = AverageMeter()
data_time_m = AverageMeter()

@@ -228,11 +236,37 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
* inputs_ii.shape[0]
/ inputs.shape[0]
)
if args.moe_freq > 0:
moe_args = MoEArgs(
hidden_size=model.dim,
ffn_hidden_size=model.dim * 4,
moe_num_experts=args.moe_num_experts,
num_layers=model.n_layers // 2,
moe_expert_model_parallelism=True,
moe_top_k=args.moe_top_k,
device=torch.distributed.get_rank(),
moe_capacity_factor=args.moe_capacity_factor,
moe_loss_weight=args.moe_loss_weight,
fp16=False,
)
local_load_balancing_loss = batched_load_balancing_loss(moe_args)
clear_load_balancing_loss()

local_loss = local_load_balancing_loss + local_loss

backward(local_loss, scaler)
if ii == 0:
total_loss = local_loss
if args.moe_freq > 0:
total_loss = local_loss - local_load_balancing_loss
total_load_balancing_loss = local_load_balancing_loss
else:
total_loss = local_loss
else:
total_loss += local_loss
if args.moe_freq > 0:
total_loss += local_loss - local_load_balancing_loss
total_load_balancing_loss += local_load_balancing_loss
else:
total_loss += local_loss

if scaler is not None:
if args.grad_clip_norm is not None:
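
For intuition, `batched_load_balancing_loss` is understood to compute a Switch-Transformer-style auxiliary loss over the routing decisions accumulated since the last `clear_load_balancing_loss()` call; a rough standalone sketch (illustrative names and shapes, not megablocks' actual API):

```
# Rough standalone sketch of a Switch-Transformer-style load-balancing loss,
# approximating what megablocks' batched_load_balancing_loss is understood to
# compute. Names and shapes are illustrative only.
import torch

def load_balancing_loss(router_probs, expert_assignments, num_experts, loss_weight):
    # router_probs: (num_tokens, num_experts) softmax router outputs
    # expert_assignments: (num_tokens,) chosen expert index per token
    f = torch.bincount(expert_assignments, minlength=num_experts).float()
    f = f / expert_assignments.numel()   # fraction of tokens routed to each expert
    p = router_probs.mean(dim=0)         # mean router probability per expert
    return loss_weight * num_experts * torch.sum(f * p)
```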
@@ -268,13 +302,18 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler

# gathered_loss = [torch.zeros_like(total_loss) for _ in range(args.world_size)]
# torch.distributed.all_gather(gathered_loss, total_loss)

# losses_m.update(sum(gathered_loss).item() / args.world_size, batch_size * args.world_size)
losses_m.update(global_loss_tensor.item(), batch_size)
if args.moe_freq > 0:
load_balancing_losses_m.update(total_load_balancing_loss.item(), batch_size)
samples_per_second = inputs.numel() * args.world_size / batch_time_m.val
samples_per_second_per_gpu = inputs.numel() / batch_time_m.val
loss_str = f"Loss: {losses_m.avg:.3f}"
loss_str += f" LB-Loss: {load_balancing_losses_m.avg:.3f}" if args.moe_freq > 0 else ""
logging.info(
f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
f"Loss: {losses_m.avg:.3f} "
f"{loss_str} "
f"Data (t): {data_time_m.avg:.3f} "
f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
f"LR: {optimizer.param_groups[0]['lr']:5f} "
@@ -283,6 +322,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
# Save train loss / etc. Using non avg meter values as loggers have their own smoothing
log_data = {
"loss": losses_m.val,
"load_balancing_loss": load_balancing_losses_m.val,
"data_time": data_time_m.val,
"batch_time": batch_time_m.val,
"samples_per_second": samples_per_second,
1 change: 0 additions & 1 deletion scripts/generate.py
@@ -64,7 +64,6 @@ def main():

add_model_args(parser)
args = parser.parse_args()

print("Loading model into the right classes...")
open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args)))

7 changes: 7 additions & 0 deletions tests/shared.py
@@ -73,6 +73,13 @@ def __init__(self, model, **kwargs):
self.target_mask_left = None
self.target_mask_individual = None
self.ignore_parse_errors = False
self.moe_num_experts = None
self.moe_freq = 0
self.moe_weight_parallelism = False
self.moe_expert_model_parallelism = False
self.moe_capacity_factor = 1.25
self.moe_loss_weight = 0.1
self.moe_top_k = 2

for k, v in kwargs.items():
if hasattr(self, k):
7 changes: 7 additions & 0 deletions tests/test_generate_kv_cache_time.py
@@ -33,6 +33,13 @@ def test_generate_kv_cache(wiki_page, context_len, max_gen_len):
"qk_norm": False,
"positional_embedding_type": "rotary",
"ffn_type": "swiglu",
"moe_num_experts": None,
"moe_freq": 0,
"moe_weight_parallelism": False,
"moe_expert_model_parallelism": False,
"moe_capacity_factor": 1.25,
"moe_loss_weight": 0.1,
"moe_top_k": 2,
}
)
