Commit 95aec77 (1 parent: c5bc3fe)
Showing 2 changed files with 256 additions and 1 deletion.
@@ -0,0 +1,236 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# evaluate performance along various axes of sentence complexity"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## data\n",
"## this is for the case where we already have the clean dataset from before and just want to evaluate other models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. load dataset to csv\n", | ||
"2. load the model \n", | ||
"3. run the model function on each sentence in dataset...\n", | ||
"4. batch this process\n", | ||
"5. hope for no race condition when saving" | ||
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"start\n"
]
}
],
"source": [
"print(\"start\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"dataset_max_len = 200\n",
"len_of_dataset = 100000\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"len: 898502\n",
"len: 850701\n",
"len: 847588\n",
"data loaded\n",
"len: 256\n"
]
}
],
"source": [
"import pandas as pd\n",
"from src.ByT5Dataset import ByT5ConstEnigmaDataset, ByT5CaesarRandomDataset, ByT5NoisyVignere2Dataset\n",
"from src.evaluation import Model\n",
"from src.ByT5Dataset import ByT5Dataset\n",
"\n",
"models = {\n",
"    'caesar': Model(ByT5CaesarRandomDataset, 'caesar', 'en', 16677),\n",
"    'en_constenigma': Model(ByT5ConstEnigmaDataset, 'en_constenigma', 'en', 17510),\n",
"    'de_constenigma': Model(ByT5ConstEnigmaDataset, 'de_constenigma', 'de', 18065),\n",
"    'cs_constenigma': Model(ByT5ConstEnigmaDataset, 'cs_constenigma', 'cs', 18066),\n",
"    'en_noisevignere_checkpoint-5000': Model(ByT5NoisyVignere2Dataset, 'en_noisevignere_checkpoint-5000', 'en', 20145, True, 5000, .15), # 20196\n",
"    'en_noisevignere_checkpoint-10000': Model(ByT5NoisyVignere2Dataset, 'en_noisevignere_checkpoint-10000', 'en', 20145, True, 10000, .15) # 20228\n",
"}\n",
"\n",
"evaluated_name = 'en_noisevignere_checkpoint-10000'\n",
"model_metadata = models[evaluated_name]\n",
"\n",
"data_path = f'news.2013.{model_metadata.language}.trainlen.200.evaluation.100000.csv'\n",
"data = pd.read_csv(data_path)\n",
"\n",
"print(\"data loaded\")\n",
"print(model_metadata)\n",
"\n",
"\n"
]
},
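{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Model` helper imported from `src.evaluation` is not part of this diff. The cell below is only a hedged sketch of what it presumably holds, inferred from the calls above (dataset class, name, language, run id, optional checkpoint flag/step and noise proportion) and from the attributes used later (`dataset_class`, `name`, `language`, `noise_proportion`, `path()`); the field names and path layout are assumptions, not the project's actual definition."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch of src.evaluation.Model, inferred only from its usage in\n",
"# this notebook; the real class may differ (field names and path() layout are assumptions).\n",
"from dataclasses import dataclass\n",
"from typing import Optional, Type\n",
"\n",
"@dataclass\n",
"class ModelSketch:\n",
"    dataset_class: Type          # e.g. ByT5NoisyVignere2Dataset\n",
"    name: str                    # e.g. 'en_noisevignere_checkpoint-10000'\n",
"    language: str                # 'en', 'de' or 'cs'\n",
"    run_id: int                  # training run identifier\n",
"    use_checkpoint: bool = False\n",
"    checkpoint_step: Optional[int] = None\n",
"    noise_proportion: Optional[float] = None\n",
"\n",
"    def path(self) -> str:\n",
"        # Assumed layout: intermediate checkpoints live in a checkpoint-<step> subfolder.\n",
"        base = f'models/{self.name}'\n",
"        if self.use_checkpoint and self.checkpoint_step is not None:\n",
"            return f'{base}/checkpoint-{self.checkpoint_step}'\n",
"        return base"
]
},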
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## load the model and run the batched evaluation"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed batch 1/16 in 77.57 seconds\n",
"Processed batch 2/16 in 153.34 seconds\n",
"Processed batch 3/16 in 218.47 seconds\n",
"Processed batch 4/16 in 285.92 seconds\n",
"Processed batch 5/16 in 353.50 seconds\n",
"Processed batch 6/16 in 430.63 seconds\n",
"Processed batch 7/16 in 491.84 seconds\n",
"Processed batch 8/16 in 551.48 seconds\n",
"Processed batch 9/16 in 614.72 seconds\n",
"Processed batch 10/16 in 674.50 seconds\n",
"Processed batch 11/16 in 731.88 seconds\n",
"Processed batch 12/16 in 797.07 seconds\n",
"Processed batch 13/16 in 871.31 seconds\n",
"Processed batch 14/16 in 941.20 seconds\n",
"Processed batch 15/16 in 1013.33 seconds\n",
"Processed batch 16/16 in 1087.18 seconds\n",
"Average errors: 88.78125\n",
"Median errors: 87\n",
"Mode errors: 71\n",
"#############################################\n",
"avg: 88.78125\n",
"med: 87\n",
"mode: 71\n"
]
}
],
"source": [
"from transformers import ByT5Tokenizer, T5ForConditionalGeneration\n",
"from src.utils import levensthein_distance, print_avg_median_mode_error\n",
"from transformers import pipeline, logging\n",
"from src.ByT5Dataset import ByT5CaesarRandomDataset, ByT5ConstEnigmaDataset\n",
"import torch\n",
"import time\n",
"\n",
"logging.set_verbosity(logging.ERROR)\n",
"\n",
"tokenizer = ByT5Tokenizer()\n",
"\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"model = T5ForConditionalGeneration.from_pretrained(model_metadata.path())\n",
"model.to(device) # type: ignore\n",
"\n",
"dataset_class = model_metadata.dataset_class\n",
"if model_metadata.noise_proportion is not None:\n",
" dataset = dataset_class(dataset=data.text, max_length=dataset_max_len, noise_proportion=model_metadata.noise_proportion) # type: ignore , not sound but will work in my usecase\n", | ||
"else:\n", | ||
" dataset = dataset_class(dataset=data.text, max_length=dataset_max_len) # type: ignore , not sound but will work in my usecase\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"translate = pipeline(\"translation\", model=model, tokenizer=tokenizer, device=device)\n", | ||
"\n", | ||
"error_col_name = f'{model_metadata.name}_error_count'\n", | ||
"generated_text_col_name = f'{model_metadata.name}_generated_text'\n", | ||
"\n", | ||
"data[error_col_name] = 0\n", | ||
"data[generated_text_col_name] = ''\n", | ||
"\n", | ||
"batch_size = 64\n", | ||
"data = data.reset_index(drop=True)\n", | ||
"for i in range(0, len(dataset), batch_size):\n", | ||
" t0 = time.time()\n", | ||
" batch = dataset[i:i+batch_size]\n", | ||
" input_texts = batch['input_text'] \n", | ||
" output_texts = batch['output_text'] \n", | ||
"\n", | ||
" # Generate translations in batches\n", | ||
" generated_texts = translate(input_texts, max_length=(dataset_max_len + 1) * 2)\n", | ||
" generated_texts = [t['translation_text'] for t in generated_texts] # type: ignore \n", | ||
"\n", | ||
" # Calculate errors and update DataFrame\n", | ||
" errors = [levensthein_distance(gen, out) for gen, out in zip(generated_texts, output_texts)]\n", | ||
" data.loc[i:i+batch_size-1, generated_text_col_name] = generated_texts\n", | ||
" data.loc[i:i+batch_size-1, error_col_name] = errors\n", | ||
" t1 = time.time()\n", | ||
"\n", | ||
" print(f\"Processed batch {i // batch_size + 1}/{len(dataset) // batch_size} in {t1 - t0:.2f} seconds\")\n", | ||
"\n", | ||
"\n", | ||
"avg, med, mode = print_avg_median_mode_error(data[error_col_name].tolist())\n", | ||
"print(\"#############################################\")\n", | ||
"\n", | ||
"print(\"avg:\", avg)\n", | ||
"print(\"med:\", med)\n", | ||
"print(\"mode:\", mode)\n", | ||
"\n" | ||
] | ||
}, | ||
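{
"cell_type": "markdown",
"metadata": {},
"source": [
"The error counts above come from `src.utils.levensthein_distance`, which is not included in this diff. Assuming it computes the standard Levenshtein (edit) distance between the generated text and the expected plaintext, here is a minimal sketch; the project's own implementation may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of a Levenshtein (edit) distance, assuming this is roughly what\n",
"# src.utils.levensthein_distance computes; shown only to make the error metric concrete.\n",
"def levenshtein_sketch(a: str, b: str) -> int:\n",
"    # Classic dynamic programming with a single rolling row.\n",
"    prev = list(range(len(b) + 1))\n",
"    for i, ca in enumerate(a, start=1):\n",
"        curr = [i]\n",
"        for j, cb in enumerate(b, start=1):\n",
"            cost = 0 if ca == cb else 1\n",
"            curr.append(min(prev[j] + 1,          # deletion\n",
"                            curr[j - 1] + 1,      # insertion\n",
"                            prev[j - 1] + cost))  # substitution\n",
"        prev = curr\n",
"    return prev[-1]\n",
"\n",
"# 'kitten' -> 'sitting' needs three edits (two substitutions, one insertion).\n",
"assert levenshtein_sketch('kitten', 'sitting') == 3"
]
},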
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"data.to_csv(f\"news.2013.{model_metadata.language}.trainlen.200.evaluation.100000.intermediate.{model_metadata.name}.csv\", index=False)"
]
},
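{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 5 of the plan worries about a race condition when saving. The sketch below is not part of the original workflow; it is one hedged way to make the write above safer when several evaluation runs may target the same directory: write to a temporary file first and atomically replace the destination."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: write the CSV to a temporary file and atomically swap it into place,\n",
"# so a concurrent run never observes a half-written file. The helper name is hypothetical.\n",
"import os\n",
"import tempfile\n",
"\n",
"def save_csv_atomically(df, target_path):\n",
"    directory = os.path.dirname(os.path.abspath(target_path))\n",
"    fd, tmp_path = tempfile.mkstemp(suffix='.csv', dir=directory)\n",
"    os.close(fd)\n",
"    try:\n",
"        df.to_csv(tmp_path, index=False)\n",
"        os.replace(tmp_path, target_path)  # atomic rename on a single filesystem\n",
"    finally:\n",
"        if os.path.exists(tmp_path):\n",
"            os.remove(tmp_path)\n",
"\n",
"# Usage would mirror the cell above, e.g. save_csv_atomically(data, '<output csv path>')."
]
}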
],
"metadata": {
"kernelspec": {
"display_name": "enigmavenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}