generalist dataloader created with sentiment as an additional example task
Maitreyapatel committed Mar 10, 2023
1 parent 7f607d1 commit 5e6c8f8
Showing 16 changed files with 963 additions and 64 deletions.
File renamed without changes.
748 changes: 748 additions & 0 deletions data/parrot_sentiment140.csv

Large diffs are not rendered by default.

38 changes: 27 additions & 11 deletions reliability_score/augmentation/augments.py
@@ -55,9 +55,17 @@ def augment(self):


class parrot_paraphraser(Augmentation):
def __init__(self, __name__="parrot", dataset=None, csv_path=None):
def __init__(
self,
__name__="parrot",
dataset=None,
dataset_name="none",
csv_path=None,
cols=None,
):
super().__init__(__name__, dataset)
self.csv_path = csv_path
self.csv_path = csv_path.format(dataset_name)
self.cols = cols

from parrot import Parrot

@@ -69,17 +77,25 @@ def perform_augmentation(self, dataset):
new_dataset = {k: [] for k in datacols}

for i in tqdm(range(len(dataset))):
newh = self.parrot.augment(input_phrase=dataset["hypothesis"][i], use_gpu=True)
newp = self.parrot.augment(input_phrase=dataset["premise"][i], use_gpu=True)

if newp and newh:
for j in range(min(len(newh), len(newp))):
new_dataset["premise"].append(newp[j][0])
new_dataset["hypothesis"].append(newh[j][0])
tmp_data = {}
tmp_flag = True
tmp_min = 10 # fixed maximum parrot augmentations
for col in self.cols:
tmp_ = self.parrot.augment(input_phrase=dataset[col][i], use_gpu=True)
tmp_data[col] = tmp_
if not tmp_:
tmp_flag = False
else:
tmp_min = min(tmp_min, len(tmp_))

if tmp_flag:
for j in range(tmp_min):
for col in self.cols:
new_dataset[col].append(tmp_data[col][j][0])
new_dataset["label"].append(dataset["label"][i])
new_dataset["mapping"].append(i)
for k in datacols:
if k not in ["premise", "hypothesis", "label", "mapping"]:
if k not in ["label", "mapping"] + self.cols:
new_dataset[k].append(dataset[k][i])

new_dataset = pd.DataFrame(new_dataset)
@@ -93,5 +109,5 @@ def augment(self):
new_dataset = Dataset.from_pandas(pd.read_csv(self.csv_path, delimiter="\t"))
self.dataset = new_dataset.cast_column(
"label",
ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]),
self.dataset.features["label"],
)
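
For each example, the new loop paraphrases every configured column, skips the example if any column fails, and caps the number of kept paraphrases at the smallest per-column count so rows stay aligned. A toy sketch of that alignment logic, with hypothetical data (parrot's `augment` returns a list of `(paraphrase, score)` tuples, or nothing on failure):

```python
# Toy sketch of the generalized alignment logic above (hypothetical data).
tmp_data = {
    "premise": [("a cat sat on the mat", 0.9), ("a cat was on the mat", 0.8)],
    "hypothesis": [("there is a cat", 0.95)],
}

if all(tmp_data.values()):  # tmp_flag: every column produced paraphrases
    tmp_min = min(len(v) for v in tmp_data.values())  # cap at the smallest count
    rows = [{col: tmp_data[col][j][0] for col in tmp_data} for j in range(tmp_min)]
    print(rows)  # [{'premise': 'a cat sat on the mat', 'hypothesis': 'there is a cat'}]
```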
4 changes: 3 additions & 1 deletion reliability_score/configs/augmentation/parrot.yaml
@@ -2,4 +2,6 @@ parrot_paraphraser:
_target_: reliability_score.augmentation.augments.parrot_paraphraser
_partial_: true
__name__: "parrot"
csv_path: "./data/parrot_mnli.csv"
dataset_name: ${datamodule.dataset_specific_args.name}
csv_path: "./data/parrot_{}.csv"
cols: ${datamodule.dataset_specific_args.cols}
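
`csv_path` is now a template: Hydra resolves `dataset_name` from the datamodule config, and `parrot_paraphraser` fills it via `str.format`. A minimal sketch of the resolution, using the two dataset names configured later in this commit:

```python
# Sketch: how the templated csv_path resolves for the two datamodule configs below.
csv_path = "./data/parrot_{}.csv"
for dataset_name in ("multi_nli", "sentiment140"):
    print(csv_path.format(dataset_name))
# ./data/parrot_multi_nli.csv
# ./data/parrot_sentiment140.csv
```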
3 changes: 3 additions & 0 deletions reliability_score/configs/augmentation/sentiment.yaml
@@ -0,0 +1,3 @@
defaults:
- default.yaml
- parrot.yaml
33 changes: 33 additions & 0 deletions reliability_score/configs/custom_model/roberta_base_sentiment.yaml
@@ -0,0 +1,33 @@
## the parameters below are self-explanatory
model_name: "cardiffnlp/twitter-xlm-roberta-base-sentiment"
model_type: "discriminative" ## allowed types: "encode-decode", "decoder-only", "bart", "discriminative", "shared", "hybrid", "t5"
huggingface_class: null
decoder_model_name: null
model_path: null ## provide the path if you have custom trained model using transformers library
tie_embeddings: false
label: null
tie_encoder_decoder: false
pipeline: null

additional_model_inputs: null ## specify additional pre-defined model inputs, e.g. beam_search settings for generative models

tokenizer:
model_name: ${..model_name} ## only specify a name from huggingface if it differs from the actual model
label2id: ## this varies with the evaluation data; refer to your selected dataset config
negative: 0
neutral: 1
positive: 2
args:
truncation: true
padding: "max_length"

## use following dataset pre-processing steps
data_processing:
header: null ## prompt header for the input data
footer: null ## prompt footer for signaling the output
separator: " [SEP] " ## separator token; leave `null` for generative models
columns:
null
## define this only for generative or prompt-engineered models, as shown below
# premise: null
# hypothesis: null
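
The tokenizer's `label2id` block has to agree with the id order baked into the checkpoint. A quick sanity check, assuming the `transformers` library is available (the expected mapping here is an assumption based on this config, not verified output):

```python
# Sketch: confirm the config's label2id matches the checkpoint's own mapping.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment")
print(cfg.label2id)  # expected per the config above: {'negative': 0, 'neutral': 1, 'positive': 2}
```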
13 changes: 12 additions & 1 deletion reliability_score/configs/datamodule/mnli.yaml
@@ -1,8 +1,19 @@
_target_: reliability_score.datamodules.mnli_datamodule.MNLIDataModule
_target_: reliability_score.datamodules.common_datamodule.GeneralDataModule
data_dir: ${paths.data_dir}
batch_size: 1
num_workers: 0
pin_memory: False
tokenizer_data: ${custom_model.tokenizer}
model_type: ${custom_model.model_type}
data_processing: ${custom_model.data_processing}
dataset_specific_args:
label_conversion: null
label2id:
0: "entailment"
1: "neutral"
2: "contradiction"
cols: ["premise", "hypothesis"]
name: multi_nli
split: validation_matched
remove_cols: ["promptID", "pairID"]
label_col: "label"
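
With the hard-coded MNLI logic gone, these `dataset_specific_args` drive a generic loading path. Roughly what `GeneralDataModule.prepare_data` does with this config, as a sketch (assumes the `datasets` library):

```python
# Sketch: the generic loading path applied to the MNLI config above.
from datasets import load_dataset

data = load_dataset("multi_nli", split="validation_matched")
data = data.remove_columns(["promptID", "pairID"])
# label_col is already "label" and label_conversion is null, so no remapping is needed.
```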
22 changes: 22 additions & 0 deletions reliability_score/configs/datamodule/sentiment140.yaml
@@ -0,0 +1,22 @@
_target_: reliability_score.datamodules.common_datamodule.GeneralDataModule
data_dir: ${paths.data_dir}
batch_size: 1
num_workers: 0
pin_memory: False
tokenizer_data: ${custom_model.tokenizer}
model_type: ${custom_model.model_type}
data_processing: ${custom_model.data_processing}
dataset_specific_args:
label_conversion:
0: 0
2: 1
4: 2
label2id:
0: "negative"
1: "neutral"
2: "positive"
cols: ["text"]
name: sentiment140
split: test
remove_cols: ["query", "user", "date"]
label_col: "sentiment"
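
sentiment140 stores polarity as 0/2/4, so `label_conversion` remaps it to contiguous ids before `label2id` names them. A simplified sketch of the remapping (assumes the `datasets` library; the datamodule itself goes through an intermediate `converted_label` column):

```python
# Sketch: remap sentiment140's 0/2/4 polarity to contiguous ids, as configured above.
from datasets import load_dataset

label_conversion = {0: 0, 2: 1, 4: 2}  # raw polarity -> class id
data = load_dataset("sentiment140", split="test")
data = data.rename_column("sentiment", "label")  # label_col: "sentiment"
data = data.map(lambda ex: {"label": label_conversion[ex["label"]]}, batched=False)
```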
2 changes: 1 addition & 1 deletion reliability_score/configs/eval.yaml
@@ -13,7 +13,7 @@ defaults:
- hydra: default.yaml
- custom_model: bert_base_uncased.yaml

- experiment: null
- task: null

task_name: "eval"

(file path not shown)
@@ -6,7 +6,7 @@
defaults:
- override /datamodule: mnli.yaml
- override /model: inference.yaml
- override /callbacks: mnli.yaml
- override /callbacks: general_evals.yaml
- override /trainer: default.yaml
- override /custom_model: bert_base_uncased.yaml

(file path not shown)
@@ -7,7 +7,7 @@ defaults:
- override /augmentation: mnli.yaml
- override /datamodule: mnli.yaml
- override /model: inference.yaml
- override /callbacks: mnli.yaml
- override /callbacks: general_evals.yaml
- override /trainer: default.yaml
- override /custom_model: bert_base_uncased.yaml
- override /logger: csv.yaml
20 changes: 20 additions & 0 deletions reliability_score/configs/task/sentiment.yaml
@@ -0,0 +1,20 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
- override /augmentation: sentiment.yaml
- override /datamodule: sentiment140.yaml
- override /model: inference.yaml
- override /callbacks: general_evals.yaml
- override /trainer: default.yaml
- override /custom_model: bert_base_uncased.yaml
- override /logger: csv.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["sentiment", "test"]

seed: 42
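
With the `experiment` config group renamed to `task` in eval.yaml above, this config would presumably be selected as `task=sentiment` on the command line, alongside a matching model, e.g. `custom_model=roberta_base_sentiment`.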
(file path not shown)
@@ -1,52 +1,24 @@
import logging

import logging
from typing import Any, Dict, Optional, Tuple

import torch
from datasets import concatenate_datasets, load_dataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from transformers import AutoTokenizer, T5Tokenizer


class mnli_tokenization:
def __init__(
self,
model_name: str,
model_type: str,
is_generative_model: bool,
tokenizer_args: dict,
data_processing: dict,
label2id: dict,
cols: list,
label_col: str,
):
if model_type != "t5":
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
else:
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.is_generative_model = is_generative_model
self.data_processing = data_processing
self.tokenizer_args = tokenizer_args
self.label2id = label2id
self.label_col = label_col
self.cols = cols

def process(self, example):
return self.tokenizer(example["input_data"], **self.tokenizer_args)

from reliability_score.datamodules.utils import (
process_label2id,
general_tokenization,
conversion_process,
)

def process_label2id(gt_label2id, pred_label2id):
assert len(gt_label2id) == len(pred_label2id)

dataset_converion = {}
for i in list(gt_label2id.keys()):
dataset_converion[i] = pred_label2id[gt_label2id[i]]
return dataset_converion


class MNLIDataModule(LightningDataModule):
class GeneralDataModule(LightningDataModule):
def __init__(
self,
dataset_specific_args: dict,
tokenizer_data: dict,
model_type: str,
data_dir: str = "data/",
@@ -58,8 +30,13 @@ def __init__(
):
super().__init__()

self.label2id = {0: "entailment", 1: "neutral", 2: "contradiction"}
self.cols = ["premise", "hypothesis"]
self.label2id = dataset_specific_args["label2id"]
self.cols = dataset_specific_args["cols"]
self.dataset_name = dataset_specific_args["name"]
self.dataset_split = dataset_specific_args["split"]
self.dataset_rmcols = dataset_specific_args["remove_cols"]
self.label_col = dataset_specific_args["label_col"]
self.label_conversion = dataset_specific_args["label_conversion"]

self.batch_size = batch_size
self.num_workers = num_workers
@@ -71,14 +48,14 @@ def __init__(
self.data_processing = data_processing
self.is_generative_model = False if model_type == "discriminative" else True

self.tokenization = mnli_tokenization(
self.tokenization = general_tokenization(
model_name=self.tokenizer_data["model_name"],
is_generative_model=self.is_generative_model,
tokenizer_args=self.tokenizer_data["args"],
data_processing=self.data_processing,
label2id=self.label2id,
model_type=model_type,
label_col="label",
label_col=self.label_col,
cols=self.cols,
)

@@ -102,7 +79,9 @@ def custom_prepocess(self, dataset):
if self.data_processing.columns:
for column_name, column_prefix in self.data_processing.columns.items():
dataset = dataset.map(
lambda example: {column_name: " ".join([column_prefix, example[column_name]])},
lambda example: {
column_name: " ".join([column_prefix, example[column_name]])
},
batched=False,
)

@@ -139,8 +118,19 @@ def custom_prepocess(self, dataset):
return dataset

def prepare_data(self):
self.data_test = load_dataset("multi_nli", split="validation_matched")
self.data_test = self.data_test.remove_columns(["promptID", "pairID"])
self.data_test = load_dataset(self.dataset_name, split=self.dataset_split)
self.data_test = self.data_test.remove_columns(self.dataset_rmcols)
if self.label_col != "label":
self.data_test = self.data_test.rename_column(self.label_col, "label")
if self.label_conversion:
self.data_test = self.data_test.map(
lambda batch: {"converted_label": self.label_conversion[batch["label"]]},
batched=False,
remove_columns=["label"],
)
self.data_test = self.data_test.rename_column("converted_label", "label")
# cvp = conversion_process(self.label_conversion)
# self.data_test.map(cvp.process, batched=True)

keys = [i for i in range(len(self.data_test))]
self.data_test = self.data_test.add_column("primary_key", keys)
@@ -154,7 +144,9 @@ def prepare_data(self):
logging.info("Performing tokenization...")
old_columns = set(list(self.data_test.features.keys()))
self.data_test = self.data_test.map(self.tokenization.process, batched=True)
self.label_conversion = process_label2id(self.label2id, self.tokenizer_data.label2id)
self.label_conversion = process_label2id(
self.label2id, self.tokenizer_data.label2id
)
self.data_test = self.data_test.map(
lambda batch: {"converted_label": self.label_conversion[batch["label"]]},
batched=False,
@@ -185,11 +177,12 @@ def prepare_data(self):
]
+ list(new_columns - old_columns),
)
# self.data_test = self.data_test.align_labels_with_mapping(self.label2id, "label")

def setup(self, stage: Optional[str] = None):
if not self.data_test:
logging.error("It seems that dataset object was not declared. Attempting it again.")
logging.error(
"It seems that dataset object was not declared. Attempting it again."
)
self.prepare_data()

def train_dataloader(self):
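
`process_label2id` now lives in `reliability_score.datamodules.utils`; it composes the dataset's id-to-name map with the model's name-to-id map into a direct id-to-id lookup. A minimal sketch with a hypothetical model mapping:

```python
# Sketch: compose dataset id->name with model name->id into an id->id lookup.
def process_label2id(gt_label2id, pred_label2id):
    assert len(gt_label2id) == len(pred_label2id)
    return {k: pred_label2id[v] for k, v in gt_label2id.items()}

dataset_label2id = {0: "entailment", 1: "neutral", 2: "contradiction"}
model_label2id = {"entailment": 2, "neutral": 1, "contradiction": 0}  # hypothetical
print(process_label2id(dataset_label2id, model_label2id))  # {0: 2, 1: 1, 2: 0}
```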
(remaining file diffs not rendered)