Skip to content

Commit

Permalink
compatible datasets version and unified python versions to 3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
jmercat committed Dec 9, 2023
1 parent 8ffeba8 commit 1285146
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 107 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
Expand Down
2 changes: 1 addition & 1 deletion environment-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: open_lm_tests
channels:
- defaults
dependencies:
- python=3.8
- python=3.10
- pip
- pip:
- -r requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: open_lm
channels:
- defaults
dependencies:
- python=3.8
- python=3.10
- pip
- pip:
- -r requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion open_lm/model_configs/open_lm_1b.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
}
}
2 changes: 1 addition & 1 deletion open_lm/model_configs/open_lm_7b.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
}
}
8 changes: 5 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ jsonlines
boto3==1.26.90
Pillow
zstandard
llm-foundry
pysimdjson
cloudpathlib
apache_beam
datasets~=2.5.2
datasets
multiprocess>=0.70.11
dill
llm-foundry
huggingface_hub
pre-commit
ray[all]
Expand All @@ -23,4 +24,5 @@ jsonlines
transformers
s3fs
wikipedia
ipython

1 change: 0 additions & 1 deletion tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def __init__(self):


def create_train_fixtures(model="open_lm_11m", fsdp=False):

# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_dataset_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

def retrieve_dataset(epoch, next_shard, weights, seed, disable_buffer, min_shards_needed=2):
args = parse_args("")

train_data_string_per_source, num_samples_per_source, _ = get_string_for_epoch(
NUM_SAMPLES, [next_shard, next_shard], INPUT_PATHS, weights, min_shards_needed, world_size=1
)
Expand All @@ -46,7 +46,7 @@ def retrieve_dataset(epoch, next_shard, weights, seed, disable_buffer, min_shard
args.rank = 0
data = get_wds_dataset(args, is_train=True, epoch=epoch, force_num_samples=num_samples_per_source)
dl = data.dataloader

return dl


Expand All @@ -67,7 +67,7 @@ def retrieve_dataset_resampled(epoch, next_shard, weights, seed, min_shards_need
args.rank = 0
data = get_wds_dataset(args, is_train=True, epoch=epoch)
dl = data.dataloader

return dl


Expand Down
8 changes: 7 additions & 1 deletion tests/test_dataset_no_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def retrieve_dataset_once(
weights=weights,
num_workers_per_gpu=num_workers,
world_size=1,
multi_epoch=False,
multi_epoch=False
)

args.train_num_samples = total_seqs
Expand All @@ -96,6 +96,7 @@ def retrieve_dataset_once(
return data



# ======================================================
# = Single Source Test Cases =
# ======================================================
Expand Down Expand Up @@ -173,6 +174,7 @@ def test_singleSource_singleWorker_imperfectBatch(num_samples, next_shard, batch
assert len(data_ids) == batch_size * (num_samples // batch_size) #



def test_singleSource_multiWorker_0():
"""
Asking for 200 samples with 2 workers and a batchsize of 10.
Expand Down Expand Up @@ -332,6 +334,7 @@ def test_singleSource_multiWorker_3():
assert sorted(target_data_ids) == sorted(data_ids)



def test_singleSource_multiWorker_4():
"""
Asking for 256 samples from 2 workers
Expand Down Expand Up @@ -369,3 +372,6 @@ def test_singleSource_multiWorker_4():
data_ids.append(tuple(seq[:3]))
target_data_ids = [(0, i, j) for i in range(2) for j in range(100)] # all of shard 000, 001
assert sorted(target_data_ids) == sorted(data_ids)



105 changes: 52 additions & 53 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

from open_lm.file_utils import get_string_for_epoch

import pytest
import pytest
import os
import math
import json

from tests.utils import download_dl_test_data, make_fake_tarfiles

Expand All @@ -21,40 +22,45 @@
SINGLE_SOURCE = ["tests/assets/source_id_00/manifest.jsonl"]
# ^ 6 files with 100 sequences, 1 file with 66 sequences


@pytest.mark.parametrize("num_samples,starting_point", [(10, 0), (10, 1), (100, 2)])
@pytest.mark.parametrize(
"num_samples,starting_point",
[(10, 0), (10, 1), (100, 2)]
)
def test_gsfe_ss_0(num_samples, starting_point):
"""Test case when we want to consume exactly one file, with a single worker"""
""" Test case when we want to consume exactly one file, with a single worker """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
num_samples, [starting_point], SINGLE_SOURCE, None, num_workers_per_gpu=1, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(num_samples, [starting_point], SINGLE_SOURCE, None,
num_workers_per_gpu=1, world_size=1, multi_epoch=False)

assert shards_ps == [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "%08d.tar" % starting_point)]
assert shards_ps == [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '%08d.tar' % starting_point)]
assert nums_ps == ([100] if starting_point < 6 else [66])
assert next_ps == [starting_point + 1]


@pytest.mark.parametrize("num_samples,starting_point", [(101, 0), (250, 1), (150, 5)])
@pytest.mark.parametrize(
"num_samples,starting_point",
[(101, 0), (250, 1), (150, 5)]
)
def test_gsfe_ss_1(num_samples, starting_point):
"""Test case when we want to consume multiple files, with a single worker"""
""" Test case when we want to consume multiple files, with a single worker """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
num_samples, [starting_point], SINGLE_SOURCE, None, num_workers_per_gpu=1, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(num_samples, [starting_point], SINGLE_SOURCE, None,
num_workers_per_gpu=1, world_size=1, multi_epoch=False)


expected_num_shards = math.ceil(num_samples / 100.0)
expected_shardlist = ["%08d" % i for i in range(starting_point, starting_point + expected_num_shards)]
expected_shardlist = ['%08d' % i for i in range(starting_point, starting_point + expected_num_shards)]


expected_num_samples = expected_num_shards * 100
if expected_shardlist[-1] == "%08d" % 6:
if expected_shardlist[-1] == '%08d' % 6:
expected_num_samples -= 34

expected_shard_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "{%s}.tar" % ",".join(expected_shardlist))]
expected_shard_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '{%s}.tar' % ','.join(expected_shardlist))]
expected_nums_ps = [expected_num_samples]
expected_next_ps = [starting_point + expected_num_shards]

Expand All @@ -63,52 +69,49 @@ def test_gsfe_ss_1(num_samples, starting_point):
assert expected_next_ps == next_ps


@pytest.mark.parametrize("num_samples,starting_point", [(1000, 0), (200, 5)])

@pytest.mark.parametrize(
"num_samples,starting_point",
[(1000, 0), (200, 5)]
)
def test_gsfe_ss_2(num_samples, starting_point):
"""Test case when we want to consume too many samples, with a single worker"""
""" Test case when we want to consume too many samples, with a single worker """
download_dl_test_data()
make_fake_tarfiles()

try:
get_string_for_epoch(
num_samples, [starting_point], SINGLE_SOURCE, None, num_workers_per_gpu=1, world_size=1, multi_epoch=False
)
get_string_for_epoch(num_samples, [starting_point], SINGLE_SOURCE, None,
num_workers_per_gpu=1, world_size=1, multi_epoch=False)
except IndexError:
assert True


@pytest.mark.parametrize("num_workers,world_size,starting_point", [(10, 1, 0), (5, 2, 0), (3, 3, 0), (3, 1, 5)])
@pytest.mark.parametrize(
"num_workers,world_size,starting_point",
[(10, 1, 0), (5, 2, 0), (3, 3, 0), (3, 1, 5)]
)
def test_gsfe_ss_3(num_workers, world_size, starting_point):
"""Test case when we want to consume data but have too many workers"""
""" Test case when we want to consume data but have too many workers """
download_dl_test_data()
make_fake_tarfiles()

try:
get_string_for_epoch(
42,
[starting_point],
SINGLE_SOURCE,
None,
num_workers_per_gpu=num_workers,
world_size=world_size,
multi_epoch=False,
)
get_string_for_epoch(42, [starting_point], SINGLE_SOURCE, None,
num_workers_per_gpu=num_workers,
world_size=world_size, multi_epoch=False)
except IndexError:
assert True


def test_gsfe_ss_4():
"""Test case when we want to consume a small amount of data, but with multiple workers"""
""" Test case when we want to consume a small amount of data, but with multiple workers """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
10, [0], SINGLE_SOURCE, None, num_workers_per_gpu=2, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(10, [0], SINGLE_SOURCE, None,
num_workers_per_gpu=2, world_size=1, multi_epoch=False)

expected_shards_ps = [
os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "{%s}.tar" % ",".join(["%08d" % i for i in range(2)]))
]
expected_shards_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '{%s}.tar' % ','.join(['%08d' % i for i in range(2)]))]
expected_nums_ps = [200]
expected_next_ps = [2]

Expand All @@ -118,17 +121,14 @@ def test_gsfe_ss_4():


def test_gsfe_ss_5():
"""Test case when we want to consume a reasonable amount of data, multiple workers"""
""" Test case when we want to consume a reasonable amount of data, multiple workers """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
400, [0], SINGLE_SOURCE, None, num_workers_per_gpu=2, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(400, [0], SINGLE_SOURCE, None,
num_workers_per_gpu=2, world_size=1, multi_epoch=False)

expected_shards_ps = [
os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "{%s}.tar" % ",".join(["%08d" % i for i in range(4)]))
]
expected_shards_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '{%s}.tar' % ','.join(['%08d' % i for i in range(4)]))]
expected_nums_ps = [400]
expected_next_ps = [4]

Expand All @@ -137,21 +137,20 @@ def test_gsfe_ss_5():
assert expected_next_ps == next_ps



def test_gsfe_ss_6():
"""Test case when we want a reasonable amount of data, but uneven modulo num workers"""
""" Test case when we want a reasonable amount of data, but uneven modulo num workers """
download_dl_test_data()
make_fake_tarfiles()

shards_ps, nums_ps, next_ps = get_string_for_epoch(
450, [0], SINGLE_SOURCE, None, num_workers_per_gpu=2, world_size=1, multi_epoch=False
)
shards_ps, nums_ps, next_ps = get_string_for_epoch(450, [0], SINGLE_SOURCE, None,
num_workers_per_gpu=2, world_size=1, multi_epoch=False)

expected_shards_ps = [
os.path.join(os.path.dirname(SINGLE_SOURCE[0]), "{%s}.tar" % ",".join(["%08d" % i for i in range(4)]))
]
expected_shards_ps = [os.path.join(os.path.dirname(SINGLE_SOURCE[0]), '{%s}.tar' % ','.join(['%08d' % i for i in range(4)]))]
expected_nums_ps = [400]
expected_next_ps = [4]

assert shards_ps == expected_shards_ps
assert expected_nums_ps == nums_ps
assert expected_next_ps == next_ps

48 changes: 25 additions & 23 deletions tests/test_make_wds_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,40 @@
"""


@pytest.mark.parametrize("source_dir", ["source_1", "source_2"])
@pytest.mark.parametrize('source_dir', ['source_1', 'source_2'])
def test_make_manifest_from_source(source_dir):
download_dl_test_data("tests/assets")
download_dl_test_data('tests/assets')

MOCK_MANIFEST = "tests/assets/%s/mock_manifest.jsonl" % source_dir
if os.path.exists(MOCK_MANIFEST):
os.unlink(MOCK_MANIFEST)
MOCK_MANIFEST = 'tests/assets/%s/mock_manifest.jsonl' % source_dir
if os.path.exists(MOCK_MANIFEST):
os.unlink(MOCK_MANIFEST)

args = ["--data-dir", "tests/assets/%s" % source_dir, "--manifest-filename", "mock_manifest.jsonl"]
mwm.main(args)
args = ['--data-dir', 'tests/assets/%s' % source_dir, '--manifest-filename', 'mock_manifest.jsonl']
mwm.main(args)

true_manifest = "tests/assets/%s/manifest.jsonl" % source_dir
with open(true_manifest, "r") as true_file:
with open(MOCK_MANIFEST, "r") as mock_file:
assert true_file.read() == mock_file.read()

if os.path.exists(MOCK_MANIFEST):
os.unlink(MOCK_MANIFEST)
true_manifest = 'tests/assets/%s/manifest.jsonl' % source_dir
with open(true_manifest, 'r') as true_file:
with open(MOCK_MANIFEST, 'r') as mock_file:
assert true_file.read() == mock_file.read()

if os.path.exists(MOCK_MANIFEST):
os.unlink(MOCK_MANIFEST)


def test_make_toplevel_manifest():
download_dl_test_data("tests/assets")
download_dl_test_data('tests/assets')

MOCK_MANIFEST = 'tests/assets/mock_manifest.jsonl'
if os.path.exists(MOCK_MANIFEST):
os.unlink(MOCK_MANIFEST)

MOCK_MANIFEST = "tests/assets/mock_manifest.jsonl"
if os.path.exists(MOCK_MANIFEST):
os.unlink(MOCK_MANIFEST)
args = ['--data-dir', 'tests/assets/', '--manifest-filename', 'mock_manifest.jsonl']
mwm.main(args)

args = ["--data-dir", "tests/assets/", "--manifest-filename", "mock_manifest.jsonl"]
mwm.main(args)
lines = [json.loads(_) for _ in open(MOCK_MANIFEST, 'r').readlines()]
assert lines == [{'shard': 'shard_00000000', 'num_sequences': 120}]

lines = [json.loads(_) for _ in open(MOCK_MANIFEST, "r").readlines()]
assert lines == [{"shard": "shard_00000000", "num_sequences": 120}]

if os.path.exists(MOCK_MANIFEST):
os.unlink(MOCK_MANIFEST)
if os.path.exists(MOCK_MANIFEST):
os.unlink(MOCK_MANIFEST)
Loading

0 comments on commit 1285146

Please sign in to comment.