From 06193b76a5d24bb79eaad20d4c2b21548c0eb6a8 Mon Sep 17 00:00:00 2001 From: Luke Hsiao Date: Tue, 28 Apr 2020 10:35:39 -0700 Subject: [PATCH] refactor(supervision/labeler): raise error if gold contains ABSTAIN If a user calls get_gold_labels before loading gold labels, they will get a label matrix with all values being -1 (ABSTAIN) L_gold_train = labeler.get_gold_labels(train_cands, annotator='gold') print(L_gold_train[0].reshape(-1)) [-1. -1. -1. ... -1. -1. -1.] Without any indicator that that is incorrect, if they do an LFAnalysis they will have 0 empirical accuracy. To be more friendly to the user, we throw an error to indicate that this is an error. Closes #403. --- src/fonduer/supervision/labeler.py | 14 +++++++++++++- tests/e2e/test_e2e.py | 4 ++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/fonduer/supervision/labeler.py b/src/fonduer/supervision/labeler.py index b535214f0..af0b257a5 100644 --- a/src/fonduer/supervision/labeler.py +++ b/src/fonduer/supervision/labeler.py @@ -404,17 +404,29 @@ def get_gold_labels( :param annotator: A specific annotator key to get labels for. Default None. :type annotator: str + :raises ValueError: If get_gold_labels is called before gold labels are + loaded, the result will contain ABSTAIN values. We raise a + ValueError to help indicate this potential mistake to the user. :return: A list of MxN dense matrix where M are the candidates and N is the annotators. If annotator is provided, return a list of Mx1 matrix. :rtype: list[np.ndarray] """ - return [ + gold_labels = [ unshift_label_matrix(m) for m in get_sparse_matrix( self.session, GoldLabelKey, cand_lists, key=annotator ) ] + for cand_labels in gold_labels: + if ABSTAIN in cand_labels: + raise ValueError( + "Gold labels contain ABSTAIN labels. " + "Did you load gold labels beforehand?" + ) + + return gold_labels + def get_label_matrices(self, cand_lists: List[List[Candidate]]) -> List[np.ndarray]: """Load dense matrix of Labels for each candidate_class. diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index d13639950..1f043a2b5 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -337,6 +337,10 @@ def test_e2e(): labeler = Labeler(session, [PartTemp, PartVolt]) + # This should raise an error, since gold labels are not yet loaded. + with pytest.raises(ValueError): + _ = labeler.get_gold_labels(train_cands, annotator="gold") + labeler.apply( docs=last_docs, lfs=[[gold], [gold]],