import os
import torch
import linecache
import numpy as np
import polars as pl
import sentencepiece as spm
from queue import Queue
import threading


class PairedTextDataset:
    """Line-aligned parallel-text dataset tokenized with SentencePiece for seq2seq training."""

    def __init__(self, source_file, target_file, source_tokenizer: spm.SentencePieceProcessor, target_tokenizer: spm.SentencePieceProcessor):
        self.source_file = source_file
        self.target_file = target_file
        self.source_tokenizer = source_tokenizer
        self.target_tokenizer = target_tokenizer
        if not os.path.exists(source_file) or not os.path.exists(target_file):
            raise FileNotFoundError("Source or target file not found")
        # Count lines up front so alignment can be validated and batches sized.
        with open(self.source_file, 'r', encoding='utf-8') as src_file:
            self.source_lines = sum(1 for _ in src_file)
        with open(self.target_file, 'r', encoding='utf-8') as tgt_file:
            self.target_lines = sum(1 for _ in tgt_file)
        if self.source_lines != self.target_lines:
            raise ValueError("Source and target files must have the same number of lines")

    def get_polars(self):
        source_ds = pl.read_csv(self.source_file,
                                has_header=False,
                                separator="\n",
                                schema_overrides={"source": pl.Utf8},
                                quote_char=None)
        target_ds = pl.read_csv(self.target_file,
                                has_header=False,
                                separator="\n",
                                schema_overrides={"target": pl.Utf8},
                                quote_char=None)
        source_ds = source_ds.with_row_index()
        target_ds = target_ds.with_row_index()
        return source_ds.join(target_ds, on="index")

    def batch_length(self, batch_size=32):
        """Number of batches per epoch for the given batch size."""
        return int(np.ceil(self.source_lines / batch_size))

    def _get_batch(self, indices):
        """Tokenize and pad one batch of line indices into (source, target_in, target_out) tensors."""
        source_batch = []
        target_input_batch = []
        target_output_batch = []
        max_source_length = 0
        max_target_length = 0
        for idx in indices:
            # linecache is 1-indexed; the dataset indices are 0-indexed.
            source_line = linecache.getline(self.source_file, idx + 1).strip()
            target_line = linecache.getline(self.target_file, idx + 1).strip()
            source_tokens = self.source_tokenizer.Encode(source_line, add_bos=True, add_eos=True)
            target_tokens = self.target_tokenizer.Encode(target_line, add_bos=True, add_eos=True)
            source_batch.append(source_tokens)
            # Teacher forcing: decoder input drops EOS, decoder target drops BOS.
            target_input_batch.append(target_tokens[:-1])
            target_output_batch.append(target_tokens[1:])
            max_source_length = max(max_source_length, len(source_tokens))
            max_target_length = max(max_target_length, len(target_tokens) - 1)
        # Right-pad every sequence to the longest one in the batch (pad id 0).
        for idx in range(len(source_batch)):
            source_batch[idx] = torch.nn.functional.pad(torch.tensor(source_batch[idx]), (0, max_source_length - len(source_batch[idx])))
            target_input_batch[idx] = torch.nn.functional.pad(torch.tensor(target_input_batch[idx]), (0, max_target_length - len(target_input_batch[idx])))
            target_output_batch[idx] = torch.nn.functional.pad(torch.tensor(target_output_batch[idx]), (0, max_target_length - len(target_output_batch[idx])))
        return torch.stack(source_batch), torch.stack(target_input_batch), torch.stack(target_output_batch)

    def batch(self, batch_size=32):
        """Yield shuffled batches, prepared ahead of time on a background thread."""
        indices = np.arange(self.source_lines)
        np.random.shuffle(indices)
        cache_size = 200  # maximum number of prepared batches held in memory
        batch_queue = Queue(maxsize=cache_size)

        def fill_cache():
            # Producer: tokenize and pad batches, then signal completion with None.
            local_indices = indices.copy()
            while len(local_indices) > 0:
                chunk = local_indices[:batch_size]
                local_indices = local_indices[batch_size:]
                batch_queue.put(self._get_batch(chunk))
            batch_queue.put(None)

        threading.Thread(target=fill_cache, daemon=True).start()
        # Consumer: yield prepared batches until the producer signals it is done.
        while True:
            the_batch = batch_queue.get()
            if the_batch is None:
                break
            yield the_batch
if __name__ == "__main__":
pass
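    # Minimal usage sketch. The file names below ("source.model", "target.model",
    # "train.src", "train.tgt") are placeholder assumptions, not files shipped with
    # this repo: they stand in for trained SentencePiece models and a line-aligned
    # parallel corpus. Adjust the paths to your own data before running.
    source_sp = spm.SentencePieceProcessor(model_file="source.model")
    target_sp = spm.SentencePieceProcessor(model_file="target.model")
    dataset = PairedTextDataset("train.src", "train.tgt", source_sp, target_sp)
    print(f"{dataset.source_lines} sentence pairs, "
          f"{dataset.batch_length(batch_size=32)} batches per epoch")
    for src, tgt_in, tgt_out in dataset.batch(batch_size=32):
        # Each item is a padded integer tensor of shape (batch, max_len).
        print(src.shape, tgt_in.shape, tgt_out.shape)
        break  # inspect a single batch and stop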