Skip to content

Commit

Permalink
fix xlmr base + qol
Browse files (browse the repository at this point in the history)
  • Loading branch information
markus583 committed May 14, 2024
1 parent 34de88a commit 2e80a9d
Showing 1 changed file with 7 additions and 23 deletions.
30 changes: 7 additions & 23 deletions wtpsplit/evaluation/intrinsic.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -180,7 +180,8 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
if args.adapter_path:
if args.clf_from_scratch:
model.model.classifier = torch.nn.Linear(model.model.classifier.in_features, 1)
elif model.model.classifier.out_features == 2:
# elif model.model.classifier.out_features == 2:
elif args.model_path == "xlm-roberta-base" or args.model_path == "xlm-roberta-large":
# we train XLM-R using our wrapper, needs to be adapted for adapters to be loaded
model.model.classifier = torch.nn.Linear(
model.model.classifier.in_features,
Expand All @@ -203,15 +204,6 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
with_head=True,
load_as="text",
)
if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")) and not os.path.exists(
os.path.join(args.model_path, "model.safetensors")
):
model_path = os.path.join(args.model_path, dataset_name, "en")
if not os.path.exists(model_path):
model_path = args.model_path
model = PyTorchWrapper(
AutoModelForTokenClassification.from_pretrained(model_path).to(args.device)
)
except Exception as e:
print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
continue
Expand Down Expand Up @@ -337,17 +329,7 @@ def main(args):
valid_data = None

print("Loading model...")
# if model_path does not contain a model, take first subfolder
if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")) and not os.path.exists(
os.path.join(args.model_path, "model.safetensors")
):
try:
model_path = os.path.join(args.model_path, os.listdir(args.model_path)[0], "en")
except:
model_path = args.model_path
print(model_path)
else:
model_path = args.model_path
model_path = args.model_path
model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
if args.adapter_path:
model_type = model.model.config.model_type
Expand All @@ -361,14 +343,16 @@ def main(args):
model.model.classifier = torch.nn.Sequential(clf, torch.nn.Linear(clf.out_features, 1))

save_str += f"{args.save_suffix}"
if args.max_n_test_sentences < sys.maxsize:
if args.max_n_test_sentences < sys.maxsize or args.max_n_test_sentences != -1:
save_str += f"_n{args.max_n_test_sentences}"
if args.max_n_test_sentences == -1:
args.max_n_test_sentences = sys.maxsize

# first, logits for everything.
f, total_test_time = load_or_compute_logits(args, model, eval_data, valid_data, save_str)

save_str += f"_u{args.threshold}"
if args.exclude_every_k > 0:
if args.exclude_every_k > 0 or "lyrics" in args.eval_data_path:
save_str += f"_k{args.exclude_every_k}"

# now, compute the intrinsic scores.
Expand Down

0 comments on commit 2e80a9d

Please sign in to comment.