Skip to content

Commit

Permalink
refactor(supervision/labeler): raise error if gold contains ABSTAIN
Browse files Browse the repository at this point in the history
If a user calls get_gold_labels before loading gold labels, they will
get a label matrix in which every value is -1 (ABSTAIN):

    L_gold_train = labeler.get_gold_labels(train_cands, annotator='gold')
    print(L_gold_train[0].reshape(-1))
    [-1. -1. -1. ... -1. -1. -1.]

Nothing indicates that this result is incorrect, so if they then run an
LFAnalysis, they will see an empirical accuracy of 0.

To be more user-friendly, we now raise a ValueError when the gold labels
contain ABSTAIN, signaling that gold labels were likely not loaded.

Closes #403.
  • Loading branch information
lukehsiao committed Apr 29, 2020
1 parent ce3357a commit 06193b7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/fonduer/supervision/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,17 +404,29 @@ def get_gold_labels(
:param annotator: A specific annotator key to get labels for. Default
None.
:type annotator: str
:raises ValueError: If get_gold_labels is called before gold labels are
loaded, the result will contain ABSTAIN values. We raise a
ValueError to help indicate this potential mistake to the user.
:return: A list of MxN dense matrix where M are the candidates and N is the
annotators. If annotator is provided, return a list of Mx1 matrix.
:rtype: list[np.ndarray]
"""
return [
gold_labels = [
unshift_label_matrix(m)
for m in get_sparse_matrix(
self.session, GoldLabelKey, cand_lists, key=annotator
)
]

for cand_labels in gold_labels:
if ABSTAIN in cand_labels:
raise ValueError(
"Gold labels contain ABSTAIN labels. "
"Did you load gold labels beforehand?"
)

return gold_labels

def get_label_matrices(self, cand_lists: List[List[Candidate]]) -> List[np.ndarray]:
"""Load dense matrix of Labels for each candidate_class.
Expand Down
4 changes: 4 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def test_e2e():

labeler = Labeler(session, [PartTemp, PartVolt])

# This should raise an error, since gold labels are not yet loaded.
with pytest.raises(ValueError):
_ = labeler.get_gold_labels(train_cands, annotator="gold")

labeler.apply(
docs=last_docs,
lfs=[[gold], [gold]],
Expand Down

0 comments on commit 06193b7

Please sign in to comment.