Skip to content

Commit

Permalink
fix ted filter
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed May 17, 2024
1 parent edab657 commit 8809327
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
if args.include_langs is not None and lang_code not in args.include_langs:
continue

# print(f"Processing {lang_code}...")
if lang_code not in f:
lang_group = f.create_group(lang_code)
else:
Expand All @@ -166,19 +165,16 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
for dataset_name, dataset in tqdm(eval_data[lang_code]["sentence"].items(), desc=lang_code):
if args.skip_corrupted and "corrupted" in dataset_name:
continue
if "corrupted-asr" in dataset_name and (
"lyrics" not in dataset_name
and "short" not in dataset_name
and "code" not in dataset_name
and "ted" not in dataset_name
and "legal" not in dataset_name
if "asr" in dataset_name and not any(
x in dataset_name for x in ["lyrics", "short", "code", "ted2020", "legal"]
):
print("SKIP: ", lang_code, dataset_name)
logger.warning(f"SKIP: {lang_code} {dataset_name}")
continue
if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
print("SKIP: ", lang_code, dataset_name)
logger.warning(f"SKIP: {lang_code} {dataset_name}")
continue
if "social-media" in dataset_name:
logger.warning(f"SKIP: {lang_code} {dataset_name}")
continue
if "nllb" in dataset_name:
continue
Expand Down Expand Up @@ -214,14 +210,15 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
load_as="text",
)
except Exception as e:
print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
logger.error(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
continue
if dataset_name not in lang_group:
dset_group = lang_group.create_group(dataset_name)
else:
dset_group = lang_group[dataset_name]

if "test_logits" not in dset_group:
# logger.warning(f"RUN: {lang_code} {dataset_name}")
test_sentences = dataset["data"]
if not test_sentences:
continue
Expand Down Expand Up @@ -348,7 +345,7 @@ def main(args):
else:
valid_data = None

print("Loading model...")
logger.warning("Loading model...")
model_path = args.model_path
model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
if args.adapter_path:
Expand Down Expand Up @@ -387,7 +384,7 @@ def main(args):
if args.include_langs is not None and lang_code not in args.include_langs:
continue

print(f"Predicting {lang_code}...")
logger.warning(f"Predicting {lang_code}...")
results[lang_code] = {}
clfs[lang_code] = {}
if args.return_indices:
Expand Down Expand Up @@ -426,7 +423,7 @@ def main(args):
skip_punct=args.skip_punct,
)
if clf[0] is not None:
print(clf)
logger.warning(clf)

if isinstance(sentences[0], list):
acc_t, acc_punct = [], []
Expand Down Expand Up @@ -575,10 +572,10 @@ def main(args):
if score_punct is not None:
punct_scores.append((score_punct, lang_code))

# just for printing
# just for logging
score_t = score_t or 0.0
score_punct = score_punct or 0.0
print(f"{lang_code} {dataset_name} {score_u:.3f} {score_t:.3f} {score_punct:.3f}")
logger.warning(f"{lang_code} {dataset_name} {score_u:.3f} {score_t:.3f} {score_punct:.3f}")

# Compute statistics for each metric across all languages
results_avg = {
Expand Down

0 comments on commit 8809327

Please sign in to comment.