Skip to content

Commit

Permalink
compatible datasets version and unified python versions to 3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
jmercat committed Dec 9, 2023
1 parent 8ffeba8 commit 753297a
Show file tree
Hide file tree
Showing 14 changed files with 119 additions and 108 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
Expand Down
2 changes: 1 addition & 1 deletion environment-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: open_lm_tests
channels:
- defaults
dependencies:
- python=3.8
- python=3.10
- pip
- pip:
- -r requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: open_lm
channels:
- defaults
dependencies:
- python=3.8
- python=3.10
- pip
- pip:
- -r requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion open_lm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed)
random.seed(worker_seed)

Expand Down
2 changes: 1 addition & 1 deletion open_lm/model_configs/open_lm_1b.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
}
}
2 changes: 1 addition & 1 deletion open_lm/model_configs/open_lm_7b.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
}
}
8 changes: 5 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ jsonlines
boto3==1.26.90
Pillow
zstandard
llm-foundry
pysimdjson
cloudpathlib
apache_beam
datasets~=2.5.2
datasets
multiprocess>=0.70.11
dill
llm-foundry
huggingface_hub
pre-commit
ray[all]
Expand All @@ -23,4 +24,5 @@ jsonlines
transformers
s3fs
wikipedia
ipython

1 change: 0 additions & 1 deletion tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def __init__(self):


def create_train_fixtures(model="open_lm_11m", fsdp=False):

# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_dataset_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

def retrieve_dataset(epoch, next_shard, weights, seed, disable_buffer, min_shards_needed=2):
args = parse_args("")

train_data_string_per_source, num_samples_per_source, _ = get_string_for_epoch(
NUM_SAMPLES, [next_shard, next_shard], INPUT_PATHS, weights, min_shards_needed, world_size=1
)
Expand All @@ -46,7 +46,7 @@ def retrieve_dataset(epoch, next_shard, weights, seed, disable_buffer, min_shard
args.rank = 0
data = get_wds_dataset(args, is_train=True, epoch=epoch, force_num_samples=num_samples_per_source)
dl = data.dataloader

return dl


Expand All @@ -67,7 +67,7 @@ def retrieve_dataset_resampled(epoch, next_shard, weights, seed, min_shards_need
args.rank = 0
data = get_wds_dataset(args, is_train=True, epoch=epoch)
dl = data.dataloader

return dl


Expand Down
8 changes: 7 additions & 1 deletion tests/test_dataset_no_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def retrieve_dataset_once(
weights=weights,
num_workers_per_gpu=num_workers,
world_size=1,
multi_epoch=False,
multi_epoch=False
)

args.train_num_samples = total_seqs
Expand All @@ -96,6 +96,7 @@ def retrieve_dataset_once(
return data



# ======================================================
# = Single Source Test Cases =
# ======================================================
Expand Down Expand Up @@ -173,6 +174,7 @@ def test_singleSource_singleWorker_imperfectBatch(num_samples, next_shard, batch
assert len(data_ids) == batch_size * (num_samples // batch_size) #



def test_singleSource_multiWorker_0():
"""
Asking for 200 samples with 2 workers and a batchsize of 10.
Expand Down Expand Up @@ -332,6 +334,7 @@ def test_singleSource_multiWorker_3():
assert sorted(target_data_ids) == sorted(data_ids)



def test_singleSource_multiWorker_4():
"""
Asking for 256 samples from 2 workers
Expand Down Expand Up @@ -369,3 +372,6 @@ def test_singleSource_multiWorker_4():
data_ids.append(tuple(seq[:3]))
target_data_ids = [(0, i, j) for i in range(2) for j in range(100)] # all of shard 000, 001
assert sorted(target_data_ids) == sorted(data_ids)



105 changes: 52 additions & 53 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

from open_lm.file_utils import get_string_for_epoch

import pytest
import pytest
import os
import math
import json

from tests.utils import download_dl_test_data, make_fake_tarfiles

Expand All @@ -21,40 +22,45 @@
SINGLE_SOURCE = ["tests/assets/source_id_00/manifest.jsonl"]
# ^ 6 files with 100 sequences, 1 file with 66 sequences


@pytest.mark.parametrize("num_samples,starting_point", [(10, 0), (10, 1), (100, 2)])
@pytest.mark.parametrize(
"num_samples,starting_point",
[(10, 0), (10, 1), (100, 2)]
)
def test_gsfe_ss_0(num_samples, starting_point):
"""Test case when we want to consume exactly one file, with a single worker"""
""" Test case when we want to consume exactly one file, with a single worker """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
num_samples, [starting_point], SINGLE_SOURCE, None, num_workers_per_gpu=1, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(num_samples, [starting_point], SINGLE_SOURCE, None,
num_workers_per_gpu=1, world_size=1, multi_epoch=False)

assert shards_ps == [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "%08d.tar" % starting_point)]
assert shards_ps == [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '%08d.tar' % starting_point)]
assert nums_ps == ([100] if starting_point < 6 else [66])
assert next_ps == [starting_point + 1]


@pytest.mark.parametrize("num_samples,starting_point", [(101, 0), (250, 1), (150, 5)])
@pytest.mark.parametrize(
"num_samples,starting_point",
[(101, 0), (250, 1), (150, 5)]
)
def test_gsfe_ss_1(num_samples, starting_point):
"""Test case when we want to consume multiple files, with a single worker"""
""" Test case when we want to consume multiple files, with a single worker """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
num_samples, [starting_point], SINGLE_SOURCE, None, num_workers_per_gpu=1, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(num_samples, [starting_point], SINGLE_SOURCE, None,
num_workers_per_gpu=1, world_size=1, multi_epoch=False)


expected_num_shards = math.ceil(num_samples / 100.0)
expected_shardlist = ["%08d" % i for i in range(starting_point, starting_point + expected_num_shards)]
expected_shardlist = ['%08d' % i for i in range(starting_point, starting_point + expected_num_shards)]


expected_num_samples = expected_num_shards * 100
if expected_shardlist[-1] == "%08d" % 6:
if expected_shardlist[-1] == '%08d' % 6:
expected_num_samples -= 34

expected_shard_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "{%s}.tar" % ",".join(expected_shardlist))]
expected_shard_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '{%s}.tar' % ','.join(expected_shardlist))]
expected_nums_ps = [expected_num_samples]
expected_next_ps = [starting_point + expected_num_shards]

Expand All @@ -63,52 +69,49 @@ def test_gsfe_ss_1(num_samples, starting_point):
assert expected_next_ps == next_ps


@pytest.mark.parametrize("num_samples,starting_point", [(1000, 0), (200, 5)])

@pytest.mark.parametrize(
"num_samples,starting_point",
[(1000, 0), (200, 5)]
)
def test_gsfe_ss_2(num_samples, starting_point):
"""Test case when we want to consume too many samples, with a single worker"""
""" Test case when we want to consume too many samples, with a single worker """
download_dl_test_data()
make_fake_tarfiles()

try:
get_string_for_epoch(
num_samples, [starting_point], SINGLE_SOURCE, None, num_workers_per_gpu=1, world_size=1, multi_epoch=False
)
get_string_for_epoch(num_samples, [starting_point], SINGLE_SOURCE, None,
num_workers_per_gpu=1, world_size=1, multi_epoch=False)
except IndexError:
assert True


@pytest.mark.parametrize("num_workers,world_size,starting_point", [(10, 1, 0), (5, 2, 0), (3, 3, 0), (3, 1, 5)])
@pytest.mark.parametrize(
"num_workers,world_size,starting_point",
[(10, 1, 0), (5, 2, 0), (3, 3, 0), (3, 1, 5)]
)
def test_gsfe_ss_3(num_workers, world_size, starting_point):
"""Test case when we want to consume data but have too many workers"""
""" Test case when we want to consume data but have too many workers """
download_dl_test_data()
make_fake_tarfiles()

try:
get_string_for_epoch(
42,
[starting_point],
SINGLE_SOURCE,
None,
num_workers_per_gpu=num_workers,
world_size=world_size,
multi_epoch=False,
)
get_string_for_epoch(42, [starting_point], SINGLE_SOURCE, None,
num_workers_per_gpu=num_workers,
world_size=world_size, multi_epoch=False)
except IndexError:
assert True


def test_gsfe_ss_4():
"""Test case when we want to consume a small amount of data, but with multiple workers"""
""" Test case when we want to consume a small amount of data, but with multiple workers """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
10, [0], SINGLE_SOURCE, None, num_workers_per_gpu=2, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(10, [0], SINGLE_SOURCE, None,
num_workers_per_gpu=2, world_size=1, multi_epoch=False)

expected_shards_ps = [
os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "{%s}.tar" % ",".join(["%08d" % i for i in range(2)]))
]
expected_shards_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '{%s}.tar' % ','.join(['%08d' % i for i in range(2)]))]
expected_nums_ps = [200]
expected_next_ps = [2]

Expand All @@ -118,17 +121,14 @@ def test_gsfe_ss_4():


def test_gsfe_ss_5():
"""Test case whne we want to consume a reasonable amount of data, multiple workers"""
""" Test case whne we want to consume a reasonable amount of data, multiple workers """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
400, [0], SINGLE_SOURCE, None, num_workers_per_gpu=2, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(400, [0], SINGLE_SOURCE, None,
num_workers_per_gpu=2, world_size=1, multi_epoch=False)

expected_shards_ps = [
os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "{%s}.tar" % ",".join(["%08d" % i for i in range(4)]))
]
expected_shards_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '{%s}.tar' % ','.join(['%08d' % i for i in range(4)]))]
expected_nums_ps = [400]
expected_next_ps = [4]

Expand All @@ -137,21 +137,20 @@ def test_gsfe_ss_5():
assert expected_next_ps == next_ps



def test_gsfe_ss_6():
"""Test case when we want to reasonable data, but uneven modulo num workers"""
""" Test case when we want to reasonable data, but uneven modulo num workers """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
450, [0], SINGLE_SOURCE, None, num_workers_per_gpu=2, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(450, [0], SINGLE_SOURCE, None,
num_workers_per_gpu=2, world_size=1, multi_epoch=False)

expected_shards_ps = [
os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "{%s}.tar" % ",".join(["%08d" % i for i in range(4)]))
]
expected_shards_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '{%s}.tar' % ','.join(['%08d' % i for i in range(4)]))]
expected_nums_ps = [400]
expected_next_ps = [4]

assert shards_ps == expected_shards_ps
assert expected_nums_ps == nums_ps
assert expected_next_ps == next_ps

Loading

0 comments on commit 753297a

Please sign in to comment.