Skip to content

Commit

Permalink
(refactor) src
Browse files Browse the repository at this point in the history
(add) 20 reproducible training noisy random vignere2
  • Loading branch information
JanProvaznik committed Jan 3, 2024
1 parent a89f1ff commit c5bc3fe
Show file tree
Hide file tree
Showing 6 changed files with 452 additions and 16 deletions.
6 changes: 3 additions & 3 deletions evaluation_batchedgpuevaluate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@
"# aaaaaa\n",
"model_ids = {\n",
" 'caesar': 'slurm_16677',\n",
" 'en': 'slurm_17510',\n",
" 'de': 'slurm_18065',\n",
" 'cs': 'slurm_18066'\n",
" 'en_constenigma': 'slurm_17510',\n",
" 'de_constenigma': 'slurm_18065',\n",
" 'cs_constenigma': 'slurm_18066'\n",
"}\n",
"model_id = model_ids[lang]\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
Expand Down
416 changes: 416 additions & 0 deletions reproducible/20_vignere_noisy_random_news_en.ipynb

Large diffs are not rendered by default.

26 changes: 22 additions & 4 deletions src/ByT5Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def __getitem__(self, idx) -> dict:
return_tensors="pt",
)
return {
"input_ids": encoding["input_ids"].squeeze(),
"labels": encoding_labels["input_ids"].squeeze(),
"attention_mask": encoding["attention_mask"].squeeze(),
"input_ids": encoding["input_ids"].squeeze(), # type: ignore
"labels": encoding_labels["input_ids"].squeeze(), # type: ignore
"attention_mask": encoding["attention_mask"].squeeze(), # type: ignore
"input_text": input_text,
"output_text": output_text,
}
Expand Down Expand Up @@ -186,4 +186,22 @@ class ByT5ConstEnigmaDataset(ByT5DatasetOnlyPreprocessCiphertext):

def __init__(self, data, max_length) -> None:
const_enigma = ciphers.make_const_enigma()
super().__init__(const_enigma, data, max_length)
super().__init__(const_enigma, data, max_length)

class ByT5NoisyConstEnigmaDataset(ByT5DatasetOnlyPreprocessCiphertext):
    """Dataset enciphered with a constant-key Enigma cipher plus noise.

    The cipher object is obtained from ``ciphers.make_const_enigma`` and is
    handed to the parent class together with ``data`` and ``max_length``.
    """

    def __init__(self, data, max_length, noise_proportion=.1) -> None:
        """Create the noisy constant-key Enigma cipher and set up the parent.

        Args:
            data: Forwarded unchanged to the parent constructor.
            max_length: Forwarded unchanged to the parent constructor.
            noise_proportion: Passed to ``ciphers.make_const_enigma`` as its
                ``noise_proportion`` keyword argument.
        """
        super().__init__(
            ciphers.make_const_enigma(noise_proportion=noise_proportion),
            data,
            max_length,
        )


class ByT5NoisyVignere2Dataset(ByT5DatasetOnlyPreprocessCiphertext):
    """Dataset enciphered with the noisy, random 2-letter Vigenère cipher.

    The cipher handed to the parent class is
    ``ciphers.noisy_random_vignere2`` with its ``noise_proportion`` keyword
    bound in advance.
    """

    def __init__(self, data, max_length, noise_proportion=.15) -> None:
        """Bind the noise level into the cipher and set up the parent.

        Args:
            data: Forwarded unchanged to the parent constructor.
            max_length: Forwarded unchanged to the parent constructor.
            noise_proportion: Bound as the ``noise_proportion`` keyword of
                ``ciphers.noisy_random_vignere2``.
        """
        # Pre-bind the noise level so the parent receives a cipher that can
        # be called without an explicit noise_proportion argument.
        cipher = functools.partial(
            ciphers.noisy_random_vignere2,
            noise_proportion=noise_proportion,
        )
        super().__init__(cipher, data, max_length)
6 changes: 3 additions & 3 deletions src/ciphers.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def random_vignere(text: str, key_length: int) -> str:


def noisy_random_vignere(
text: str, key_length: int, noise_proportion: float = 0.1
text: str, key_length: int, noise_proportion: float = 0.15
) -> str:
"""Vigenère cipher with a random key of specified length and noise."""
noisy_text = add_noise_to_text(text, noise_proportion)
Expand All @@ -213,11 +213,11 @@ def random_vignere3(text: str) -> str:
return random_vignere(text, 3)


def noisy_random_vignere2(text: str, noise_proportion: float = 0.1) -> str:
def noisy_random_vignere2(text: str, noise_proportion: float = 0.15) -> str:
"""Vigenère cipher with a random 2-letter key and noise."""
return noisy_random_vignere(text, 2, noise_proportion)


def noisy_random_vignere3(text: str, noise_proportion: float = 0.1) -> str:
def noisy_random_vignere3(text: str, noise_proportion: float = 0.15) -> str:
"""Vigenère cipher with a random 3-letter key and noise."""
return noisy_random_vignere(text, 3, noise_proportion)
5 changes: 3 additions & 2 deletions src/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
import logging
import re
from typing import Optional
from unidecode import unidecode


Expand Down Expand Up @@ -46,7 +47,7 @@ def generate_random_dataset(


def load_dataset(
rows: int, min_length: int, max_length: int, file_path: str, seed: int = 42, exclude_length: int = None
rows: int, min_length: int, max_length: int, file_path: str, seed: int = 42, exclude_length: Optional[int] = None
) -> list[str]:
"""Samples a dataset from a file and truncates each rows to a random length generated for it from a range.
Expand Down Expand Up @@ -165,4 +166,4 @@ def weird(text: str) -> bool:
if comma_period_count > 15:
return True

return False
return False
9 changes: 5 additions & 4 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ def download_newscrawl(year=2012, language="en") -> None:
newscrawl_url = f"https://data.statmt.org/news-crawl/{language}/"
filename = f"news.{year}.{language}.shuffled.deduped.gz"
url = newscrawl_url + filename
logging.info(f"Downloading {filename} from {url}")
logging.info("Downloading %s from %s", filename, url)
os.system(f"wget {url}")
os.system(f"gunzip {filename}")
logging.info(
f"Downloaded and extracted {filename}",
f"full path: {os.path.abspath(filename)}",
"Downloaded and extracted %s, full path: %s",
filename,
os.path.abspath(filename),
)


Expand Down Expand Up @@ -86,7 +87,7 @@ def create_detect_language(lang="en"):
def detect_language(text):
# print('\n' in text.strip())
# return 'aabb'
detected = model.predict(text.strip())[0][0][-2:] # the last two characters are the language in the label
detected = model.predict(text.strip())[0][0][-2:] # type: ignore # the last two characters are the language in the label
return detected == lang

return detect_language

0 comments on commit c5bc3fe

Please sign in to comment.