
Commit

added tiny beam generation test and fixed cache reorder (thanks Rui)
jmercat committed Dec 13, 2023
1 parent ca83264 commit 28ce58b
Showing 3 changed files with 41 additions and 24 deletions.
6 changes: 3 additions & 3 deletions open_lm/utils/transformers/hf_model.py
@@ -117,10 +117,10 @@ def prepare_inputs_for_generation(
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
+        reordered_cache = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(1, beam_idx) for past_state in layer_past),)
-        return reordered_past
+            reordered_cache += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_cache
 
 
 if __name__ == "__main__":
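
Note on the fix above: during beam search, the HuggingFace generation loop calls _reorder_cache after each step so that every layer's cached keys and values follow the beams that survived. The dimension passed to index_select must therefore be the batch-times-beams dimension of the cached tensors; this commit switches it from dim 1 to dim 0. The sketch below only illustrates that behaviour and assumes a cache layout of [batch * num_beams, n_heads, seq_len, head_dim], which is an assumption for illustration rather than something shown in this diff.

import torch

# Minimal sketch of beam-aware cache reordering, assuming the beam dimension
# comes first: [batch * num_beams, n_heads, seq_len, head_dim].
def reorder_cache(past_key_values, beam_idx):
    reordered = ()
    for layer_past in past_key_values:
        # Select along dim 0 so each cached tensor continues the beam it was
        # reassigned to at this decoding step.
        reordered += (tuple(t.index_select(0, beam_idx) for t in layer_past),)
    return reordered

# Toy usage: 4 cache rows (one prompt * 4 beams); beam_idx says which old row
# each new beam continues from, so row 0 is dropped and row 1 is kept twice.
layer = (torch.arange(4.0).view(4, 1, 1, 1), torch.arange(4.0).view(4, 1, 1, 1))
cache = (layer, layer)                 # two layers, each a (key, value) pair
beam_idx = torch.tensor([1, 1, 2, 3])
new_cache = reorder_cache(cache, beam_idx)
assert new_cache[0][0][:, 0, 0, 0].tolist() == [1.0, 1.0, 2.0, 3.0]
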
9 changes: 7 additions & 2 deletions tests/test_tiny_generate_kv_cache_equal.py
Expand Up @@ -4,7 +4,8 @@
from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from .utils import run_model, CharacterTokenizer
from tests.utils import run_model, CharacterTokenizer


# Download the checkpoint from HuggingFace Hub if it doesn't exist and set the args
@pytest.mark.gpu
Expand All @@ -21,6 +22,7 @@ def args():
"temperature": 0.0,
"top_p": 1.0,
"use_cache": False,
"num_beams": 1,
# Model params that might not be in config:
"model_norm": "default_layer_norm",
"qk_norm": False,
Expand Down Expand Up @@ -49,12 +51,15 @@ def tiny_tokenizer():
@pytest.mark.parametrize("wiki_page", ["Soil steam sterilization", "The Triumph of Death"])
@pytest.mark.parametrize("context_len", [4, 8])
@pytest.mark.parametrize("max_gen_len", [4, 8])
def test_tiny_generate_kv_cache(tiny_open_lm, tiny_tokenizer, args, wiki_page, context_len, max_gen_len):
@pytest.mark.parametrize("num_beams", [1, 4])
def test_tiny_generate_kv_cache(tiny_open_lm, tiny_tokenizer, args, wiki_page, context_len, max_gen_len, num_beams):
"""
This test checks that the results of the generation are the same with and without cache.
"""
args.max_gen_len = max_gen_len
args.context_len = context_len
args.num_beams = num_beams

if max_gen_len + context_len > tiny_open_lm.model.seq_len:
pytest.skip("The model cannot generate sequences that long")

Expand Down
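
For context, the assertion this test builds toward is not inside the hunks shown; a hypothetical sketch follows. It assumes run_model returns the generated text for the supplied args, which is an assumption about tests/utils.py rather than something visible in this diff.

# Hypothetical sketch of the cache-equivalence check (assumes run_model
# returns the generated text for the given args).
args.use_cache = False
without_cache = run_model(tiny_open_lm, tiny_tokenizer, args, wiki_page=wiki_page)

args.use_cache = True
with_cache = run_model(tiny_open_lm, tiny_tokenizer, args, wiki_page=wiki_page)

# With temperature 0.0 decoding is deterministic (greedy for num_beams=1,
# deterministic beam search for num_beams=4), so the outputs must match exactly.
assert without_cache == with_cache
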
50 changes: 31 additions & 19 deletions tests/utils.py
@@ -17,6 +17,13 @@
 from typing import Dict, List, Optional, Sequence, Union
 
 from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+from transformers import (
+    LogitsProcessorList,
+    MinLengthLogitsProcessor,
+    BeamSearchScorer,
+    StoppingCriteriaList,
+    MaxLengthCriteria,
+)
 import wikipedia
 from composer.utils import dist, get_device
 from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM
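
The transformers helpers imported above are the building blocks for driving beam search by hand rather than through generate(); they are not exercised in the hunks shown here. A hypothetical sketch of how they compose follows: model and tokenizer are placeholders, and the call pattern follows the transformers GenerationMixin API rather than anything in this commit.

import torch

num_beams = 4
inputs = tokenizer("Soil steam sterilization is", return_tensors="pt")
# Every beam starts from the same prompt, so replicate the input ids per beam.
input_ids = inputs["input_ids"].repeat_interleave(num_beams, dim=0)

beam_scorer = BeamSearchScorer(batch_size=1, num_beams=num_beams, device=input_ids.device)
logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])

# Between steps, each beam's cached keys/values are shuffled via _reorder_cache,
# which is why the reordering fixed in hf_model.py matters here.
outputs = model.beam_search(
    input_ids,
    beam_scorer,
    logits_processor=logits_processor,
    stopping_criteria=stopping_criteria,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
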
@@ -78,8 +85,7 @@ def download_val_data(name: str, root: str = None):
 
 
 def download_dl_test_data(root: str = "./tests/assets"):
-    """Downloads test files if the data doesn't exist in HF cache.
-    """
+    """Downloads test files if the data doesn't exist in HF cache."""
 
     snapshot_args = dict(
         repo_id="mlfoundations/open_lm_test_data_v2",
@@ -91,47 +97,46 @@ def download_dl_test_data(root: str = "./tests/assets"):
 
 
 def make_tar(tar_num, num_lines, source_num=0, dir_name=None):
-    fname = lambda i: '%08d_chunk_%s.json' % (tar_num, i)
+    fname = lambda i: "%08d_chunk_%s.json" % (tar_num, i)
 
     if dir_name != None:
         Path(dir_name).mkdir(parents=True, exist_ok=True)
-        tarname = os.path.join(dir_name, '%08d.tar' % tar_num)
 
+    tarname = os.path.join(dir_name, "%08d.tar" % tar_num)
     if os.path.exists(tarname):
         return
 
     fnames = []
-    with tarfile.open(tarname, 'w') as tar:
+    with tarfile.open(tarname, "w") as tar:
         for line in range(num_lines):
             base_line = [666 for _ in range(2049)]
             base_line[0] = source_num
             base_line[1] = tar_num
             base_line[2] = line
             this_file = fname(line)
-            with open(this_file, 'w') as f:
+            with open(this_file, "w") as f:
                 f.write(json.dumps(base_line))
             tar.add(this_file)
             fnames.append(this_file)
 
-
-
     for f in fnames:
         try:
             os.unlink(f)
         except:
             pass
 
 
 def make_source(source_num, size_per_tar, total_size):
     num_tars = total_size // size_per_tar
     if total_size % size_per_tar != 0:
         num_tars += 1
 
     base_dir = "tests/assets"
     os.makedirs(base_dir, exist_ok=True)
 
-    num_remaining = total_size
+    num_remaining = total_size
     for tar_num in range(num_tars):
-        this_tar = min(num_remaining, size_per_tar)
+        this_tar = min(num_remaining, size_per_tar)
         make_tar(tar_num, this_tar, source_num=source_num, dir_name="tests/assets/source_id_%02d" % source_num)
         num_remaining -= this_tar
 
@@ -140,7 +145,7 @@ def make_source(source_num, size_per_tar, total_size):
 
 
 def make_fake_tarfiles():
-    """ Makes sources for dataloader tests.
+    """Makes sources for dataloader tests.
     Running main will...
     - generate 2 sources, titled 'source_id_00', 'source_id_01'
     - each source has 7 .tar files, each with 100 sequences (except the last which has 66)
@@ -152,10 +157,7 @@ def make_fake_tarfiles():
         make_source(i, 100, 666)
 
 
-@torch.inference_mode()
-def run_model(open_lm, tokenizer, args, wiki_page=None, start_index=None):
-
-    dist.initialize_dist(get_device(None), timeout=600)
+def _get_tokens_inputs(tokenizer, args, wiki_page=None, start_index=None):
     if args.input_text == "random":
         wikipedia.set_lang("en")
         try:
@@ -190,6 +192,13 @@ def run_model(open_lm, tokenizer, args, wiki_page=None, start_index=None):
         input = {k: torch.tensor(v).unsqueeze(0).cuda() for k, v in input.items()}
     else:
         input = {k: torch.tensor(v).unsqueeze(0) for k, v in input.items()}
+    return input
+
+
+@torch.inference_mode()
+def run_model(open_lm, tokenizer, args, wiki_page=None, start_index=None):
+    dist.initialize_dist(get_device(None), timeout=600)
+    input = _get_tokens_inputs(tokenizer, args, wiki_page=wiki_page, start_index=start_index)
     composer_model = SimpleComposerOpenLMCausalLM(open_lm, tokenizer)
     if torch.cuda.is_available():
         composer_model = composer_model.cuda()
@@ -200,6 +209,9 @@ def run_model(open_lm, tokenizer, args, wiki_page=None, start_index=None):
         "max_new_tokens": args.max_gen_len,
         "use_cache": args.use_cache,
     }
+
+    if args.num_beams > 1:
+        generate_args["num_beams"] = args.num_beams
     # If these are set when temperature is 0, they will trigger a warning and be ignored
     if args.temperature > 0:
         generate_args["temperature"] = args.temperature
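
The generate_args assembled here are ultimately forwarded to the HF-style generate() call on the composer-wrapped model, so beam search is only requested when num_beams > 1 and the sampling arguments that would trigger warnings at temperature 0 stay out of the dictionary. A hypothetical sketch of that hand-off (the real call site is outside the hunks shown above):

# Hypothetical continuation of run_model: pass the assembled arguments to
# generate() and decode the result. The exact call site and return handling
# are not shown in this diff.
output = composer_model.generate(input["input_ids"], **generate_args)
generated_text = tokenizer.decode(output[0].cpu().tolist())
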
