forked from extreme-bert/extreme-bert
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmerge_shards.py
91 lines (81 loc) · 3.05 KB
/
merge_shards.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# coding=utf-8
# Copyright 2022 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
# code taken from commit: ea000838156e3be251699ad6a3c8b1339c76e987
# https://github.com/IntelLabs/academic-budget-bert
# Copyright 2021 Intel Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from tqdm import tqdm
# Configure root logging once at import time: INFO level, timestamped
# "<time> - <LEVEL>: <message>" lines, then grab this module's logger.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)
def write_shard(lines, f_idx, out_dir_path, name=None, encoding=None):
    """Write ``lines`` verbatim to ``<out_dir_path>/[<name>_]shard_<f_idx>.txt``.

    Args:
        lines: Iterable of strings written as-is. No separators are added,
            so each element should already carry its own line terminator.
        f_idx: Index embedded in the output file name.
        out_dir_path: Output directory; created (with parents) if missing.
        name: Optional file-name prefix; ignored when ``None`` or empty.
        encoding: Text encoding of the output file. ``None`` (default) keeps
            the previous behaviour of using the platform's locale encoding;
            pass e.g. ``"utf-8"`` for locale-independent output.

    An existing file with the same name is overwritten.
    """
    os.makedirs(out_dir_path, exist_ok=True)
    filename = f"shard_{f_idx}.txt"
    if name:
        filename = f"{name}_{filename}"
    with open(os.path.join(out_dir_path, filename), "w", encoding=encoding) as fw:
        # One buffered call instead of a Python-level per-line write loop.
        fw.writelines(lines)
def list_files_in_dir(dir, data_prefix=".txt", file_name_grep=""):
    """Return sorted paths of regular files in ``dir`` matching both filters.

    Args:
        dir: Directory to scan (non-recursive). The name shadows the builtin
            but is kept so keyword-argument callers keep working.
        data_prefix: Substring that must appear in the file name. Despite the
            name it is matched anywhere, not as a prefix (".txt" also matches
            "a.txt.bak").
        file_name_grep: Additional substring the file name must contain;
            the empty default matches everything.

    Returns:
        List of ``os.path.join(dir, name)`` paths, sorted by file name.
        ``os.listdir`` yields entries in arbitrary, filesystem-dependent
        order; sorting makes downstream grouping of files reproducible.
    """
    return [
        os.path.join(dir, f)
        for f in sorted(os.listdir(dir))
        if os.path.isfile(os.path.join(dir, f)) and data_prefix in f and file_name_grep in f
    ]
if __name__ == "__main__":
    # CLI entry point: merge every `--ratio` consecutive input files (sorted
    # order as returned by list_files_in_dir) into one output shard named
    # "[<grep>_]shard_<k>.txt" in --output_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", help="input directory with sharded text files", required=True)
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument(
        "--ratio",
        type=int,
        default=1,
        help="Number of files to merge into a single shard",
    )
    parser.add_argument(
        "--grep",
        type=str,
        default="",
        help="A string to filter a subset of files from input directory",
    )
    args = parser.parse_args()

    # Validate with parser.error instead of assert: asserts vanish under
    # `python -O`. A ratio of 0 would otherwise raise ZeroDivisionError below,
    # and a negative ratio would silently never write a shard.
    if args.ratio < 1:
        parser.error(f"--ratio must be a positive integer, got {args.ratio}")
    if not os.path.isdir(args.data):
        parser.error(f"--data must be an existing directory, got {args.data!r}")

    dataset_files = list_files_in_dir(args.data, file_name_grep=args.grep)
    num_files = len(dataset_files)
    if num_files == 0:
        logger.warning(f"No matching input files found in {args.data}; nothing to merge")
    if num_files % args.ratio != 0:
        parser.error(f"{num_files} % {args.ratio} != 0, make sure equal shards in each merged file")

    os.makedirs(args.output_dir, exist_ok=True)
    logger.info(f"Compacting input shards into {args.output_dir}")

    # Accumulate lines of `ratio` consecutive input files, then flush them as
    # one output shard. The divisibility check above guarantees no partial
    # group is left over after the loop.
    # NOTE(review): input is read with the locale's default encoding, which
    # matches write_shard's default, so bytes round-trip consistently.
    file_lines = []
    f_idx = 0
    for file_num, f in enumerate(tqdm(dataset_files, smoothing=1), start=1):
        with open(f) as fp:
            file_lines.extend(fp.readlines())
        if file_num % args.ratio == 0:
            write_shard(file_lines, f_idx, args.output_dir, name=args.grep)
            file_lines = []
            f_idx += 1
    logger.info("Done!")