Skip to content

Commit

Permalink
Merge pull request #312 from instructlab/granite-dolomite
Browse files Browse the repository at this point in the history
Adding Dolomite Support and Bringing HF Padding-Free into Performance Parity
  • Loading branch information
mergify[bot] authored Nov 1, 2024
2 parents dc7c97d + 2a9626f commit 512949f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ numba
numpy>=1.23.5,<2.0.0 ; python_version == '3.10'
numpy>=1.26.4,<2.0.0 ; python_version != '3.10'
rich
instructlab-dolomite>=0.1.1
instructlab-dolomite>=0.2.0
trl>=0.9.4
peft
pydantic>=2.7.0
Expand Down
5 changes: 5 additions & 0 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import deepcopy
from pathlib import Path
import argparse
import json
import math
import os
import re
Expand Down Expand Up @@ -528,6 +529,10 @@ def main(args):
tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
# device = torch.device("cuda", args.local_rank)

with open(Path(args.model_name_or_path) / "config.json") as conf_json:
model_conf = json.load(conf_json)
args.model_type = model_conf["model_type"]

#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
args.local_rank = int(os.environ["LOCAL_RANK"])
Expand Down
39 changes: 21 additions & 18 deletions src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Any, List, Optional
import importlib
import inspect
import json
import logging
import os
import random
Expand Down Expand Up @@ -62,17 +61,10 @@ def check_valid_train_args(train_args: TrainingArgs):
f"Provided path to model does not exist. Please make sure that you've passed a valid model and that it has appropriate permissions: {train_args.model_path}"
)

if train_args.use_dolomite:
with open(Path(train_args.model_path) / "config.json") as conf_json:
model_conf = json.load(conf_json)
if model_conf["model_type"] == "granite":
raise RuntimeError(
"Converting Granite models to Dolomite format is currently unsupported."
)
if train_args.disable_flash_attn:
raise RuntimeError(
"ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
)
if train_args.use_dolomite and train_args.disable_flash_attn:
raise RuntimeError(
"ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
)

if train_args.is_padding_free:
print(
Expand Down Expand Up @@ -229,7 +221,7 @@ def pad_collate_fn(batch):

input_ids.extend(item["input_ids"].tolist())
labels.extend(item["labels"].tolist())
position_ids.extend(range(total_len, total_len + item_len))
position_ids.extend(range(item_len))

total_len += item_len
num_loss_counted_tokens += (item["labels"] != -100).sum().item()
Expand Down Expand Up @@ -802,10 +794,21 @@ def _get_state_dict_patched(model, unwrap=False):

output_dir.mkdir(parents=True, exist_ok=True)
if not model.module.config.architectures and convert_dolomite:
model.module.config.architectures = ["LlamaForCausalLM"]
warnings.warn(
f"Adding architectures to ckpt: {model.module.config.architectures}",
)
arch_added = False
if args.model_type == "llama":
model.module.config.architectures = ["LlamaForCausalLM"]
arch_added = True
elif args.model_type == "granite":
model.module.config.architectures = ["GraniteForCausalLM"]
arch_added = True
if arch_added:
warnings.warn(
f"Adding architectures to ckpt: {model.module.config.architectures}",
)
else:
warnings.warn(
f"Converting from dolomite, but no architecture field added to config.json",
)
model.module.config.to_json_file(output_config_file)
tokenizer.save_pretrained(output_dir)

Expand Down Expand Up @@ -834,7 +837,7 @@ def _get_state_dict_patched(model, unwrap=False):
export_to_huggingface(
pretrained_model_name_or_path=tmpdir.name,
save_path=final_output_dir,
model_type="llama",
model_type=args.model_type,
)
tmpdir.cleanup()

Expand Down

0 comments on commit 512949f

Please sign in to comment.