
Commit 8cc90c0
use new data, less verbose
markus583 committed May 4, 2024
1 parent a9f7fd9 · commit 8cc90c0
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions wtpsplit/evaluation/intrinsic.py
@@ -44,7 +44,7 @@ class Args:
     # }
     # }
     # TODO: for songs/etc., maybe feed in each sample separately?
-    eval_data_path: str = "data/all_data_24_04.pth"
+    eval_data_path: str = "data/all_data_02_05.pth"
     valid_text_path: str = None  # "data/sentence/valid.parquet"
     device: str = "cpu"
     block_size: int = 512
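
For context, the updated eval_data_path points at a torch-pickled nested dict. A minimal sketch of loading and walking it, assuming only the lang -> "sentence" -> dataset layout that the later hunks iterate over; the file name is the one from this commit, everything else is illustrative:

import torch

# Sketch: inspect the new eval file. The "sentence" key and per-dataset
# layout come from the loops later in this diff; the print is illustrative.
eval_data = torch.load("data/all_data_02_05.pth")
for lang_code in eval_data.keys():
    for dataset_name in eval_data[lang_code]["sentence"]:
        print(lang_code, dataset_name)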
@@ -109,7 +109,7 @@ def process_logits(text, model, lang_code, args):
         block_size=args.block_size,
         batch_size=args.batch_size,
         pad_last_batch=True,
-        verbose=True,
+        verbose=False,
     )
     logits = logits[0]
     if offsets_mapping is not None:
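
Hardcoding verbose=False silences the per-call progress output from the extraction call above. A hypothetical alternative, not part of this commit, would expose the flag as a field so it can be toggled per run; a minimal self-contained sketch, with a stub standing in for the real extraction call:

from dataclasses import dataclass

# Hypothetical: thread verbosity through the args object instead of
# hardcoding it. StubArgs and extract_stub are assumptions, not commit code.
@dataclass
class StubArgs:
    verbose: bool = False

def extract_stub(texts, verbose=False):
    if verbose:
        print(f"processing {len(texts)} texts")
    return [[0.0] * len(t) for t in texts]

args = StubArgs()
logits = extract_stub(["hello world"], verbose=args.verbose)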
@@ -142,16 +142,16 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
         # file is a csv: l1,l2,...
         use_langs = f.read().strip().split(",")
     else:
-        use_langs = Constants.LANGINFO.index
+        use_langs = eval_data.keys()
 
     total_test_time = 0  # Initialize total test processing time
 
     with h5py.File(logits_path, "a") as f, torch.no_grad():
-        for lang_code in use_langs:
+        for lang_code in tqdm(use_langs, desc="Languages"):
             if args.include_langs is not None and lang_code not in args.include_langs:
                 continue
 
-            print(f"Processing {lang_code}...")
+            # print(f"Processing {lang_code}...")
             if lang_code not in f:
                 lang_group = f.create_group(lang_code)
             else:
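
The loop change trades one print per language for a single tqdm progress bar. The pattern in isolation, with a stand-in dict for eval_data:

from tqdm import tqdm

# One labeled progress bar replaces a print per language.
eval_data = {"en": {}, "de": {}, "fr": {}}  # stand-in for the real eval data
for lang_code in tqdm(eval_data.keys(), desc="Languages"):
    pass  # per-language work happens here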
@@ -176,7 +176,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
                 lang_group.create_dataset("valid", data=valid_logits)
 
             # eval data
-            for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
+            for dataset_name, dataset in tqdm(eval_data[lang_code]["sentence"].items(), desc=lang_code):
                 if args.skip_corrupted and "corrupted" in dataset_name:
                     continue
                 try:
@@ -195,14 +195,14 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
                         model_path = os.path.join(args.model_path, dataset_name, "en")
                         if not os.path.exists(model_path):
                             model_path = args.model_path
-                        print(model_path)
+                        # print(model_path)
                         model = PyTorchWrapper(
                             AutoModelForTokenClassification.from_pretrained(model_path).to(args.device)
                         )
                 except Exception as e:
                     print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
                     continue
-                print(dataset_name)
+                # print(dataset_name)
                 if dataset_name not in lang_group:
                     dset_group = lang_group.create_group(dataset_name)
                 else:
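
The lang_group/dset_group bookkeeping in this diff is a resumable HDF5 cache: the file is opened in append mode, and a group is reused when a previous run already created it. A self-contained sketch of the same create-or-reuse idiom; the file name and placeholder data are illustrative:

import h5py
import numpy as np

with h5py.File("logits_cache.h5", "a") as f:  # "a": read/write, create if missing
    for lang_code in ["en", "de"]:
        # Reuse the group from an earlier run, or create it on first sight.
        lang_group = f[lang_code] if lang_code in f else f.create_group(lang_code)
        if "valid" not in lang_group:
            lang_group.create_dataset("valid", data=np.zeros((4, 2)))  # placeholder logits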
