-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
66 lines (55 loc) · 1.76 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import annotations
import argparse
import os
from dataclasses import dataclass
import fsspec
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerBase
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@dataclass
class MyDataset(Dataset):
    """Tokenize-on-access dataset pairing raw texts with regression targets.

    Tokenization happens lazily inside ``__getitem__``, so no encoded
    tensors are held in memory up front.
    """

    texts: list[str]                    # raw input strings, one per example
    labels: np.ndarray                  # per-example float targets, indexed by row
    tokenizer: PreTrainedTokenizerBase  # HuggingFace-style tokenizer called per item
    max_length: int = 128               # every encoding is padded/truncated to this length

    def __len__(self) -> int:
        """Number of examples in the dataset."""
        return len(self.texts)

    def __getitem__(self, i: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(input_ids, attention_mask, label)`` tensors for example *i*."""
        enc = self.tokenizer(
            self.texts[i],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
        )
        # NOTE(review): ids/mask are emitted as int32 — cast to long before
        # feeding an nn.Embedding if the model requires int64 indices.
        input_ids = torch.tensor(enc["input_ids"], dtype=torch.int32)
        attention_mask = torch.tensor(enc["attention_mask"], dtype=torch.int32)
        label = torch.tensor(self.labels[i], dtype=torch.float32)
        return input_ids, attention_mask, label
def create_train_dataloader(
    args: argparse.Namespace, tokenizer: PreTrainedTokenizerBase
) -> DataLoader:
    """Build a shuffled training DataLoader from a CSV dataset.

    Expects on ``args``:
        dataset: fsspec-compatible path/URL of a CSV file with a ``"safe"``
            text column plus one numeric column per optimization target.
        target_columns: iterable of ``"name:direction"`` strings; targets
            with direction ``"min"`` are negated so training can uniformly
            maximize every objective.
        max_length, batch_size, num_workers, shuffle_seed: loader knobs.

    Returns:
        A ``DataLoader`` yielding ``(input_ids, attention_mask, labels)``
        batches with ``drop_last=True`` and seeded shuffling.

    Raises:
        ValueError: if a target spec contains no ``":"`` separator.
        KeyError: if a named column is missing from the CSV.
    """
    with fsspec.open(args.dataset) as fp:
        data = pd.read_csv(fp)

    labels = []
    for target in args.target_columns:
        name, direction = target.split(":")[:2]
        # Flip the sign of minimization targets so "higher is better" holds
        # for every column of the stacked label matrix.
        labels.append(data[name].to_numpy() * (1 if direction == "max" else -1))

    dataset = MyDataset(
        # FIX: .tolist() so __getitem__ does positional indexing on a plain
        # list instead of label-based indexing on a pandas Series, which
        # breaks whenever the CSV yields a non-default index.
        texts=data["safe"].tolist(),
        labels=np.stack(labels, axis=1),
        tokenizer=tokenizer,
        max_length=args.max_length,
    )
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
        generator=torch.Generator().manual_seed(args.shuffle_seed),
        # FIX: persistent_workers=True is only legal when num_workers > 0;
        # the hard-coded True crashed single-process (num_workers=0) runs.
        persistent_workers=args.num_workers > 0,
    )