Skip to content

Commit

Permalink
fix imports + legal en
Browse files — browse the repository at this point in the history
  • Loading branch information
markus583 committed May 16, 2024
1 parent 9994910 commit edab657
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
10 changes: 1 addition & 9 deletions wtpsplit/evaluation/intrinsic_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from wtpsplit.extract import PyTorchWrapper
from wtpsplit.extract_batched import extract_batched
from wtpsplit.utils import Constants
from wtpsplit.evaluation.intrinsic import compute_statistics, corrupt
from wtpsplit.evaluation.intrinsic import compute_statistics

logger = logging.getLogger()
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -183,11 +183,6 @@ def generate_pairs(
if len(sentences[i]) + len(sentences[i + 1]) > min_k_mer_length
]

# corrupt pairs
all_pairs = [
(corrupt(pair[0], do_lowercase, do_remove_punct), corrupt(pair[1], do_lowercase, do_remove_punct))
for pair in all_pairs
]
return all_pairs


Expand Down Expand Up @@ -234,9 +229,6 @@ def generate_k_mers(
if sum(len(sentences[i + j]) for j in range(k)) > min_k_mer_length
]

# Apply corruption to k-mers
all_k_mers = [tuple(corrupt(sentence, do_lowercase, do_remove_punct) for sentence in k_mer) for k_mer in all_k_mers]

return all_k_mers


Expand Down
6 changes: 3 additions & 3 deletions wtpsplit/train/train_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
from wtpsplit.train.adaptertrainer import AdapterTrainer
from wtpsplit.train.trainer import Trainer
from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
from wtpsplit.train.evaluate import evaluate_sentence
from wtpsplit.train.train import collate_fn, setup_logging
from wtpsplit.train.utils import Model
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict, corrupt
Expand Down Expand Up @@ -400,13 +400,13 @@ def maybe_pad(text):
continue
if "nllb" in dataset_name:
continue
if lang == "en" and dataset_name == "legal-all-laws":
if lang == "en" and "legal-all-laws" in dataset_name:
# not available.
print("SKIP: ", lang, dataset_name)
continue
print("RUNNING:", dataset_name, lang)
# skip langs starting with a, b, ..., k
# if not lang.startswith(tuple("k")) and not "en-de" in lang:
# if not lang.startswith(tuple("abcd")):
# print(f"Skipping {lang} {dataset_name}")
# continue
# do model stuff here; otherwise, head params would be overwritten every time
Expand Down

0 comments on commit edab657

Please sign in to comment.