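"""Gather tokenized datasets into final training data chunks.

Each chunk worker scans every subset under TOKENIZED_DATASETS_DIR,
keeping only the lines that round-robin onto its chunk index, applying
per-subset repetition or subsampling via n_repeat, and mixing FIM/SPM
transformed lines into the StarCoder subsets. Chunks are produced in
parallel batches of CHUNK_BATCH_SIZE processes.
"""
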
import json
import multiprocessing
import os
import random
import time

import fire
import tqdm

# Cluster-specific Lustre paths; the commented-out block below is an
# alternative local-mount layout.
TOKENIZED_DATASETS_DIR = '/lustre/scratch/shared-folders/llm_project/bowen.tan/tokenized_datasets'
TOKENIZED_FIM_DATASETS_DIR = '/lustre/scratch/shared-folders/llm_project/bowen.tan/fim_datasets'
OUTPUT_DIR = '/lustre/scratch/shared-folders/llm_project/bowen.tan/final_data_chunks'
# TOKENIZED_DATASETS_DIR = '/mount/data/s3'
# TOKENIZED_FIM_DATASETS_DIR = '/mount/data/fim'
# OUTPUT_DIR = '/mount/data/data_chunks'

N_CHUNKS = 360         # total number of output chunk files
CHUNK_BATCH_SIZE = 16  # chunks gathered in parallel per batch
FIM_RATE = 0.5         # probability of using the FIM-transformed line
SPM_RATE = 0.5         # given FIM, probability of the SPM variant over PSM
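
# Worked example of the repetition logic below (derived from the code,
# for reference): a subset with n_repeat = 3 gets
# pseudo_n_chunks = 360 / 3 = 120, so the round-robin assignment repeats
# with period 120 and each sample lands in 3 of the 360 chunks
# (~3 epochs over one full pass). With n_repeat = 0.5 (StarCoder),
# pseudo_n_chunks = 720: only cycle positions below 360 match a real
# chunk, so roughly half the samples appear once and half not at all
# (~0.5 epochs).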


def distribute_samples_starcoder(subset_name,
                                 filename,
                                 output_file,
                                 file_idx_to_write,
                                 tgt_file_idx,
                                 n_repeat):
    """Round-robin a StarCoder jsonl file into chunks, mixing in FIM/SPM lines."""
    examples_path = f'{TOKENIZED_DATASETS_DIR}/{subset_name}/{filename}'
    fim_path = f'{TOKENIZED_FIM_DATASETS_DIR}/{subset_name}.spm0.0/{filename}'
    spm_path = f'{TOKENIZED_FIM_DATASETS_DIR}/{subset_name}.spm1.0/{filename}'

    pseudo_n_chunks = int(N_CHUNKS / n_repeat)
    file_idx_to_write = file_idx_to_write % pseudo_n_chunks
    random.seed(time.time() + tgt_file_idx)

    for line_idx, (line, line_fim, line_spm) in tqdm.tqdm(
            enumerate(zip(open(examples_path), open(fim_path), open(spm_path))),
            desc=f'{subset_name}/{filename} (pseudo_chunks={pseudo_n_chunks})'):
        if file_idx_to_write == tgt_file_idx % pseudo_n_chunks:
            src_filename = examples_path
            # With probability FIM_RATE, replace the plain line with its
            # fill-in-the-middle version; within that, SPM_RATE selects
            # the suffix-prefix-middle variant over prefix-suffix-middle.
            if random.random() < FIM_RATE:
                line = line_fim
                src_filename = fim_path
                if random.random() < SPM_RATE:
                    line = line_spm
                    src_filename = spm_path
            example = json.loads(line)
            example['subset_name'] = subset_name
            example['src_filename'] = src_filename
            example['line_idx_in_src'] = line_idx
            output_file.write(json.dumps(example) + '\n')
            output_file.flush()
        file_idx_to_write = (file_idx_to_write + 1) % pseudo_n_chunks
    return file_idx_to_write


def distribute_samples(subset_name,
                       filename,
                       output_file,
                       file_idx_to_write,
                       tgt_file_idx,
                       n_repeat):
    """Round-robin a plain jsonl file into chunks (no FIM/SPM mixing)."""
    examples_path = f'{TOKENIZED_DATASETS_DIR}/{subset_name}/{filename}'
    pseudo_n_chunks = int(N_CHUNKS / n_repeat)
    file_idx_to_write = file_idx_to_write % pseudo_n_chunks

    for line_idx, line in tqdm.tqdm(
            enumerate(open(examples_path)),
            desc=f'{subset_name}/{filename} (pseudo_chunks={pseudo_n_chunks})'):
        if file_idx_to_write == tgt_file_idx % pseudo_n_chunks:
            example = json.loads(line)
            example['subset_name'] = subset_name
            example['src_filename'] = examples_path
            example['line_idx_in_src'] = line_idx
            output_file.write(json.dumps(example) + '\n')
            output_file.flush()
        file_idx_to_write = (file_idx_to_write + 1) % pseudo_n_chunks
    return file_idx_to_write


def gather_chunk(tgt_file_idx):
    """Build one output chunk by walking every subset in a fixed order.

    file_idx_to_write threads through all files so the round-robin
    position carries across file and subset boundaries.
    """
    output_file = open(f'{OUTPUT_DIR}/chunk_{tgt_file_idx}.jsonl', 'w')
    file_idx_to_write = 0
    for subset_name in sorted(os.listdir(TOKENIZED_DATASETS_DIR)):
        # Per-subset repetition factor: n_repeat > 1 duplicates a subset
        # across chunks, n_repeat < 1 subsamples it (see note above).
        if subset_name == 'redpajama.c4':
            continue
        elif subset_name in ['dm-math', 'pubmed-abstracts', 'uspto']:
            n_repeat = 3
        elif subset_name == 'redpajama.wikipedia':
            n_repeat = 6
        elif subset_name == 'redpajama.arxiv':
            n_repeat = 1
        elif subset_name.startswith('redpajama.'):
            n_repeat = 3
        elif subset_name.startswith('starcoder.'):
            n_repeat = 0.5
        else:
            n_repeat = 1
        for filename in sorted(os.listdir(
                f'{TOKENIZED_DATASETS_DIR}/{subset_name}')):
            assert filename.endswith('.jsonl')
            if subset_name.startswith('starcoder.'):
                file_idx_to_write = distribute_samples_starcoder(
                    subset_name=subset_name,
                    filename=filename,
                    output_file=output_file,
                    file_idx_to_write=file_idx_to_write,
                    tgt_file_idx=tgt_file_idx,
                    n_repeat=n_repeat)
            else:
                file_idx_to_write = distribute_samples(
                    subset_name=subset_name,
                    filename=filename,
                    output_file=output_file,
                    file_idx_to_write=file_idx_to_write,
                    tgt_file_idx=tgt_file_idx,
                    n_repeat=n_repeat)
    output_file.close()
    return tgt_file_idx


def main():
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    for start in tqdm.trange(
            0, N_CHUNKS, CHUNK_BATCH_SIZE, desc='Processing chunks'):
        with multiprocessing.Pool(processes=CHUNK_BATCH_SIZE) as pool:
            # min() guards the last batch, since N_CHUNKS (360) is not a
            # multiple of CHUNK_BATCH_SIZE (16).
            results = pool.map(
                gather_chunk,
                list(range(start, min(start + CHUNK_BATCH_SIZE, N_CHUNKS))))
        print(results)


if __name__ == '__main__':
    fire.Fire(main)
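
# Typical invocation (python-fire exposes `main` as the CLI entry point):
#
#     python gather.py
#
# which writes chunk_0.jsonl ... chunk_359.jsonl under OUTPUT_DIR.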