forked from sapientinc/HRM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path puzzle_dataset.py
More file actions
199 lines (150 loc) · 7.79 KB
/
puzzle_dataset.py
File metadata and controls
199 lines (150 loc) · 7.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import os
import json
import numpy as np
import pydantic
import torch
from torch.utils.data import IterableDataset, get_worker_info
from models.losses import IGNORE_LABEL_ID
from dataset.common import PuzzleDatasetMetadata
def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
# Pack examples into a full batch
batch = []
batch_puzzle_indices = []
current_size = 0
while (start_index < group_order.size) and (current_size < global_batch_size):
# Pick a group and a puzzle from that group
group_id = group_order[start_index]
puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
start_index += 1
# Get range of the puzzle
puzzle_start = puzzle_indices[puzzle_id]
puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
append_size = min(puzzle_size, global_batch_size - current_size)
# Put into batch
batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))
current_size += append_size
return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
class PuzzleDatasetConfig(pydantic.BaseModel):
    """Configuration for ``PuzzleDataset``.

    Numeric fields are range-checked at construction, so nonsensical values
    fail immediately instead of causing a division by zero, an empty shuffle,
    or all-padding batches later on.
    """

    seed: int  # Base seed; the training shuffle is seeded with ``seed`` plus a pass counter.
    dataset_path: str  # Root directory holding one sub-directory per split.
    global_batch_size: int = pydantic.Field(..., gt=0)  # Examples per step, summed over all replicas.
    test_set_mode: bool  # True: ordered single pass over the data; False: shuffled training batches.
    epochs_per_iter: int = pydantic.Field(..., gt=0)  # Batch X epochs in an iteration to reduce overhead.
    rank: int = pydantic.Field(..., ge=0)  # Index of this replica; expected to be < num_replicas.
    num_replicas: int = pydantic.Field(..., gt=0)  # Number of replicas splitting each global batch.
class PuzzleDataset(IterableDataset):
    """Iterable dataset of puzzle examples, sharded across replicas.

    Each split directory ``<dataset_path>/<split>`` holds a ``dataset.json``
    metadata file. For every set name listed in that metadata, it also holds
    the arrays ``<set>__<field>.npy`` for ``inputs``, ``labels``,
    ``puzzle_identifiers``, ``puzzle_indices`` and ``group_indices``.

    Replica ``config.rank`` keeps rows
    ``[rank * local_batch_size, (rank + 1) * local_batch_size)`` of each global
    batch. Iteration yields ``(set_name, batch, global_effective_batch_size)``
    tuples. ``batch`` maps field names to int32 tensors with
    ``local_batch_size`` rows, padded when fewer real examples are available.
    """

    def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
        super().__init__()
        self.config = config
        self.split = split
        self.metadata = self._load_metadata()

        # Checks
        assert 0 <= self.config.rank < self.config.num_replicas, f"Rank {self.config.rank} must be in [0, {self.config.num_replicas})."
        assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
        self.local_batch_size = self.config.global_batch_size // self.config.num_replicas

        # State: arrays are loaded lazily on first iteration. ``_iters`` counts
        # training passes and offsets the shuffle seed so each pass differs.
        self._data = None
        self._iters = 0

    def _load_metadata(self) -> PuzzleDatasetMetadata:
        """Read ``<dataset_path>/<split>/dataset.json`` into the metadata model."""
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r", encoding="utf-8") as f:
            return PuzzleDatasetMetadata(**json.load(f))

    def _lazy_load_dataset(self):
        """Load every set's arrays on first use; later calls are no-ops."""
        if self._data is not None:
            return

        # inputs/labels are memory-mapped read-only; the index arrays are
        # loaded fully into memory.
        field_mmap_modes = {
            "inputs": "r",
            "labels": "r",
            "puzzle_identifiers": None,
            "puzzle_indices": None,
            "group_indices": None,
        }

        # Build everything first and assign at the end. If a load fails
        # midway, ``_data`` stays None rather than a silently partial dict.
        split_dir = os.path.join(self.config.dataset_path, self.split)
        self._data = {
            set_name: {
                field_name: np.load(os.path.join(split_dir, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
                for field_name, mmap_mode in field_mmap_modes.items()
            }
            for set_name in self.metadata.sets
        }

    def _collate_batch(self, batch):
        """Convert a dict of numpy arrays to int32 tensors with ``local_batch_size`` rows.

        Labels equal to the dataset's own ``ignore_label_id`` are remapped to
        ``IGNORE_LABEL_ID``. Short batches are padded along the first axis with
        ``pad_id`` (inputs), ``IGNORE_LABEL_ID`` (labels) and
        ``blank_identifier_id`` (puzzle identifiers).
        """
        # ``astype`` returns a copy, so the in-place label rewrite below never
        # writes to the read-only memory-mapped source arrays.
        batch = {k: v.astype(np.int32) for k, v in batch.items()}

        if self.metadata.ignore_label_id is not None:
            batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID

        pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
        if pad_size > 0:
            pad_values = {
                "inputs": self.metadata.pad_id,
                "labels": IGNORE_LABEL_ID,
                "puzzle_identifiers": self.metadata.blank_identifier_id,
            }
            batch = {
                k: np.pad(v, ((0, pad_size),) + ((0, 0),) * (v.ndim - 1), constant_values=pad_values[k])
                for k, v in batch.items()
            }

        return {k: torch.from_numpy(v) for k, v in batch.items()}

    def _iter_test(self):
        """Yield every example of every set once, in order, sharded by rank.

        Yields ``(set_name, batch, n)``, where ``n`` is the number of real
        examples in the global batch (the final batch may be short).
        """
        for set_name, dataset in self._data.items():  # type: ignore
            total_examples = len(dataset["inputs"])
            puzzle_offsets = dataset["puzzle_indices"]

            for start_index in range(0, total_examples, self.config.global_batch_size):
                end_index = min(total_examples, start_index + self.config.global_batch_size)
                local_start = start_index + self.config.rank * self.local_batch_size
                local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)

                # The puzzle owning each example is the last offset <= the
                # example index (offsets assumed ascending). The result is empty
                # when this rank's slice lies past the end of the data.
                example_ids = np.arange(local_start, local_end)
                puzzle_ids = np.searchsorted(puzzle_offsets, example_ids, side="right") - 1

                batch = self._collate_batch({
                    "inputs": dataset["inputs"][local_start:local_end],
                    "labels": dataset["labels"][local_start:local_end],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_ids],
                })
                yield set_name, batch, end_index - start_index

    def _iter_train(self):
        """Yield shuffled, fully packed training batches for every set.

        For each set, the shuffle is reseeded from ``config.seed`` plus a pass
        counter. Every group is visited ``epochs_per_iter`` times in random
        order. The final batch is dropped if it cannot be filled to
        ``global_batch_size``.
        """
        local_slice = slice(self.config.rank * self.local_batch_size, (self.config.rank + 1) * self.local_batch_size)

        for set_name, dataset in self._data.items():  # type: ignore
            # Advance the pass counter so each pass (and each set) gets a new seed.
            self._iters += 1

            # The seed depends only on config.seed and the pass counter, so
            # every replica shuffles groups in the same order.
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))
            num_groups = dataset["group_indices"].size - 1
            group_order = np.concatenate([rng.permutation(num_groups) for _ in range(self.config.epochs_per_iter)])

            start_index = 0
            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Global effective batch size, excluding pads
                global_effective_batch_size = batch_puzzle_indices.size

                # Drop the last, incomplete batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                # Keep only this replica's share of the global batch
                batch_indices = batch_indices[local_slice]
                batch_puzzle_indices = batch_puzzle_indices[local_slice]
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices],
                })
                yield set_name, batch, global_effective_batch_size

    def __iter__(self):
        """Yield batches for the configured mode, loading data on first use."""
        # Batches are not sharded across DataLoader workers, so running more
        # than one worker would yield duplicate data.
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."

        self._lazy_load_dataset()

        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()