Skip to content

Commit

Permalink
add LLM eval sentence fn
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed May 14, 2024
1 parent 2211bec commit b971b1c
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions wtpsplit/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def evaluate_sentences(

assert len(labels) == len(predictions)

return f1_score(labels, predictions), {
"recall": recall_score(labels, predictions),
"precision": precision_score(labels, predictions),
return f1_score(labels, predictions, zero_division=0), {
"recall": recall_score(labels, predictions, zero_division=0),
"precision": precision_score(labels, predictions, zero_division=0),
# pairwise: ignore end-of-text label
# only correct if we correctly predict the single newline in between the sentence pair
# --> no false positives, no false negatives allowed!
Expand All @@ -84,6 +84,40 @@ def evaluate_sentences(
"length": len(labels),
}

def evaluate_sentences_llm(
    labels, predictions, return_indices: bool = False, exclude_every_k: int = 0
):
    """Score predicted sentence-boundary labels against gold labels.

    Args:
        labels: Gold labels, one entry per position; positions equal to 1
            are treated as boundaries. Any array-like is accepted.
        predictions: Predicted labels, same length as ``labels``. Any
            array-like is accepted.
        return_indices: If true, include the indices of the truthy entries
            of ``labels`` and ``predictions`` in the result; otherwise those
            two entries are ``None``.
        exclude_every_k: If greater than 0, drop every k-th gold boundary
            position (and the final position) from both ``labels`` and
            ``predictions`` before scoring.

    Returns:
        A dict with keys ``"f1"``, ``"recall"``, ``"precision"`` (each
        computed with ``zero_division=0``), ``"correct_pairwise"`` (1 if all
        positions except the last agree, else 0), ``"true_indices"``,
        ``"predicted_indices"`` and ``"length"`` (number of scored positions).

    Raises:
        AssertionError: If ``labels`` and ``predictions`` differ in length.
    """
    # Coerce to arrays so that plain lists work too: the element-wise
    # ``labels == 1`` comparison and the boolean-mask indexing below are
    # only meaningful on ndarrays.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)

    assert len(labels) == len(predictions)

    if exclude_every_k > 0:
        true_end_indices = np.where(labels == 1)[0]
        # every k-th from those where labels are 1
        indices_to_remove = true_end_indices[exclude_every_k - 1 :: exclude_every_k]

        # mask for indices to keep
        mask = np.ones_like(labels, dtype=bool)
        mask[indices_to_remove] = False
        if mask.size:  # nothing to exclude on empty input
            mask[-1] = False  # last is always excluded

        # remove indices
        labels = labels[mask]
        predictions = predictions[mask]

        assert len(labels) == len(predictions)

    return {
        "f1": f1_score(labels, predictions, zero_division=0),
        "recall": recall_score(labels, predictions, zero_division=0),
        "precision": precision_score(labels, predictions, zero_division=0),
        # pairwise: ignore end-of-text label
        # only correct if we correctly predict the single newline in between the sentence pair
        # --> no false positives, no false negatives allowed!
        # NOTE(review): when exclude_every_k > 0 the final position has already
        # been removed by the mask, so [:-1] skips one further position here;
        # confirm this is intended.
        "correct_pairwise": int(np.all(labels[:-1] == predictions[:-1])),
        "true_indices": np.where(labels)[0].tolist() if return_indices else None,
        "predicted_indices": np.where(predictions)[0].tolist() if return_indices else None,
        "length": len(labels),
    }

def train_mixture(lang_code, original_train_x, train_y, n_subsample=None, features=None, skip_punct: bool = False):
original_train_x = torch.from_numpy(original_train_x).float()
Expand Down

0 comments on commit b971b1c

Please sign in to comment.