Skip to content

Commit 585e281

Browse files
LLM judge judgebench benchmarks (#1800)
* Move Metric class into its own file To avoid import cycle issues Signed-off-by: Martín Santillán Cooper <[email protected]> * Add MetricInferenceEngine Signed-off-by: Martín Santillán Cooper <[email protected]> * Add toxic chat LLM judge benchmarks Signed-off-by: Martín Santillán Cooper <[email protected]> * Fix imports Signed-off-by: elronbandel <[email protected]> * Update cards to use LoadJsonFIle Signed-off-by: elronbandel <[email protected]> * Fix empty template list Signed-off-by: elronbandel <[email protected]> * Another try Signed-off-by: elronbandel <[email protected]> --------- Signed-off-by: Martín Santillán Cooper <[email protected]> Signed-off-by: elronbandel <[email protected]> Co-authored-by: elronbandel <[email protected]>
1 parent 85c07cf commit 585e281

File tree

11 files changed

+553
-206
lines changed

11 files changed

+553
-206
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
from unitxt.blocks import (
    TaskCard,
)
from unitxt.catalog import add_to_catalog
from unitxt.llm_as_judge_constants import DirectCriteriaCatalogEnum
from unitxt.loaders import LoadJsonFile
from unitxt.operators import Copy, MapInstanceValues, Rename
from unitxt.task import Task
from unitxt.test_utils.card import test_card

# JUDGE-BENCH toxic_chat "jailbreaking" card: each instance is a chat message
# plus the majority human annotation of whether it is a jailbreak attempt.
# The task scores a float prediction (an LLM judge's numeric verdict) against
# the numeric value of the human label.
card = TaskCard(
    loader=LoadJsonFile(
        files={
            "train": "https://raw.githubusercontent.com/dmg-illc/JUDGE-BENCH/refs/heads/master/data/toxic_chat/toxic_chat_train.json",
            "test": "https://raw.githubusercontent.com/dmg-illc/JUDGE-BENCH/refs/heads/master/data/toxic_chat/toxic_chat_test.json",
        },
        data_classification_policy=["public"],
        # Instances live under the top-level "instances" key of each JSON file.
        data_field="instances",
    ),
    preprocess_steps=[
        Rename(field="instance", to_field="text"),
        Rename(field="annotations/jailbreaking/majority_human", to_field="label"),
        # Majority votes arrive as "0"/"1"; map them to human-readable labels.
        MapInstanceValues(
            mappers={
                "label": {
                    "0": "No",
                    "1": "Yes",
                },
            }
        ),
        # Keep a numeric copy of the label, translated through the criteria's
        # option map, to serve as the float reference for the judge's score.
        Copy(field="label", to_field="label_value"),
        MapInstanceValues(
            mappers={
                "label_value": DirectCriteriaCatalogEnum.JAILBREAK_USER_MESSAGE.value.option_map,
            }
        ),
    ],
    task=Task(
        input_fields={"text": str, "label": str},
        reference_fields={"label_value": float},
        prediction_type=float,
        metrics=[
            "metrics.spearman",
            "metrics.accuracy",
        ],
        default_template="templates.empty[postprocessors=[processors.cast_to_float_return_nan_if_failed]]",
    ),
    templates=["templates.empty[postprocessors=[processors.cast_to_float_return_nan_if_failed]]"],
)

test_card(card, demos_taken_from="test", strict=False)

add_to_catalog(
    card,
    # Fixed typo: was "judege_bench"; the catalog path should spell the
    # benchmark name "judge_bench".
    "cards.judge_bench.toxic_chat.jailbreaking",
    overwrite=True,
)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
from unitxt.blocks import (
    MapInstanceValues,
    Rename,
    TaskCard,
)
from unitxt.catalog import add_to_catalog
from unitxt.llm_as_judge_constants import DirectCriteriaCatalogEnum
from unitxt.loaders import LoadJsonFile
from unitxt.operators import Copy
from unitxt.task import Task
from unitxt.test_utils.card import test_card

# JUDGE-BENCH toxic_chat "toxicity" card: each instance is a chat message plus
# the majority human annotation of whether it is toxic. The task scores a
# float prediction (an LLM judge's numeric verdict) against the numeric value
# of the human label.
card = TaskCard(
    loader=LoadJsonFile(
        files={
            "train": "https://raw.githubusercontent.com/dmg-illc/JUDGE-BENCH/refs/heads/master/data/toxic_chat/toxic_chat_train.json",
            "test": "https://raw.githubusercontent.com/dmg-illc/JUDGE-BENCH/refs/heads/master/data/toxic_chat/toxic_chat_test.json",
        },
        data_classification_policy=["public"],
        # Instances live under the top-level "instances" key of each JSON file.
        data_field="instances",
    ),
    preprocess_steps=[
        Rename(field="instance", to_field="text"),
        Rename(field="annotations/toxicity/majority_human", to_field="label"),
        # Majority votes arrive as "0"/"1"; map them to human-readable labels.
        MapInstanceValues(
            mappers={
                "label": {
                    "0": "No",
                    "1": "Yes",
                },
            }
        ),
        # Keep a numeric copy of the label, translated through the criteria's
        # option map, to serve as the float reference for the judge's score.
        Copy(field="label", to_field="label_value"),
        MapInstanceValues(
            mappers={
                "label_value": DirectCriteriaCatalogEnum.TOXICITY.value.option_map,
            }
        ),
    ],
    task=Task(
        input_fields={"text": str, "label": str},
        reference_fields={"label_value": float},
        prediction_type=float,
        metrics=[
            "metrics.spearman",
            "metrics.accuracy",
        ],
        default_template="templates.empty[postprocessors=[processors.cast_to_float_return_nan_if_failed]]",
    ),
    templates=["templates.empty[postprocessors=[processors.cast_to_float_return_nan_if_failed]]"],
)

test_card(card, demos_taken_from="test", strict=False)

add_to_catalog(
    card,
    # Fixed typo: was "judege_bench"; the catalog path should spell the
    # benchmark name "judge_bench".
    "cards.judge_bench.toxic_chat.toxicity",
    overwrite=True,
)

src/unitxt/base_metric.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
from abc import abstractmethod
2+
from typing import (
3+
Any,
4+
Dict,
5+
List,
6+
Union,
7+
)
8+
9+
from .artifact import Artifact
10+
from .dataclass import (
11+
AbstractField,
12+
)
13+
from .deprecation_utils import deprecation
14+
from .error_utils import Documentation, UnitxtWarning
15+
from .stream import Stream
16+
from .type_utils import Type, isoftype, parse_type_string, to_type_string
17+
18+
19+
@deprecation(
    version="2.0.0",
    msg="use regular type instead of strings (e.g Dict[str] instead of 'Dict[str]')",
)
def parse_string_types_instead_of_actual_objects(obj):
    """Parse a legacy string type spec (e.g. ``"List[str]"``) into a real type.

    Deprecated compatibility shim — callers should pass actual typing
    objects directly.
    """
    parsed_type = parse_type_string(obj)
    return parsed_type
25+
26+
class Metric(Artifact):
    """Abstract base class for all unitxt metrics.

    Provides the shared plumbing used by concrete metrics: validation of
    predictions/references against ``prediction_type``, score-prefix
    handling, and helpers that read and write the per-instance
    (``instance["score"]["instance"]``) and global
    (``instance["score"]["global"]``) score dictionaries.
    """

    main_score: str = AbstractField()

    # Override 'prediction_type' with the expected type of predictions
    # and references. Example: List[str], List[Dict], str.
    # Deprecated: string specs such as "List[str]" are still accepted and
    # parsed into real types in prepare_args().
    prediction_type: Union[Type, str] = Any

    # Standard metrics can receive multiple references per prediction (in a list).
    # Some metrics support only a single reference per prediction (one element in the list).
    single_reference_per_prediction: bool = False

    # Used to add a prefix to all scores, except the "score_name" and "score" fields.
    # This is used to distinguish two scores of the same metric, operating on
    # different fields of the task.
    score_prefix: str = ""

    def prepare_args(self):
        """Normalize a deprecated string 'prediction_type' into a real type."""
        super().prepare_args()
        if isinstance(self.prediction_type, str):
            self.prediction_type = parse_string_types_instead_of_actual_objects(
                self.prediction_type
            )

    @classmethod
    def process_data_after_load(cls, data):
        # Serialized artifacts store 'prediction_type' as a string; restore
        # the actual type object on load.
        if "prediction_type" in data:
            data["prediction_type"] = parse_type_string(data["prediction_type"])
        return data

    def process_data_before_dump(self, data):
        # Mirror of process_data_after_load: serialize the type as a string.
        if "prediction_type" in data:
            if not isinstance(data["prediction_type"], str):
                data["prediction_type"] = to_type_string(data["prediction_type"])
        return data

    def _add_score_prefix(self, score_name):
        # "score", "score_name" and "num_of_instances" are canonical fields
        # and are never prefixed.
        return (
            self.score_prefix + score_name
            if score_name not in ["score", "score_name", "num_of_instances"]
            else score_name
        )

    def _add_score_prefixes_to_score_dict_and_check_against_existing_scores(
        self, scores: Dict[str, Any], existing_scores: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return ``scores`` with the prefix applied, warning on collisions with ``existing_scores``."""
        new_scores = {}
        for score_name, score in scores.items():
            score_with_prefix = self._add_score_prefix(score_name)
            # The value stored under "score_name" must itself carry the prefix
            # so that it keeps pointing at the (now prefixed) main score.
            new_scores[score_with_prefix] = (
                score if score_name not in ["score_name"] else self.score_prefix + score
            )
        for new_score_name in new_scores:
            if new_score_name in ["score", "score_name", "num_of_instances"]:
                continue
            if new_score_name in existing_scores:
                UnitxtWarning(
                    message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
                    f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
                    f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
                    f"which will yield, in this case, a score named: 'my_second_{new_score_name}')",
                    additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
                )
        return new_scores

    def _validate_references_and_prediction(self, references, predictions):
        """Check that references/predictions are same-length lists of the expected types.

        Raises:
            ValueError: if either argument is not a list, the sizes differ,
                or any element fails the per-item type checks.
        """
        if not isoftype(predictions, List[Any]):
            raise ValueError(
                f"Metric {self.get_metric_name()} should receive a list of predictions. Received predictions of type {type(predictions)}: {predictions}"
            )

        if not isoftype(references, List[Any]):
            # Fixed message: this branch validates references, not predictions.
            raise ValueError(
                f"Metric {self.get_metric_name()} should receive a list of references. Received references of type {type(references)}: {references}"
            )

        if len(references) != len(predictions):
            # Fixed message: was "doesn't mach" and printed len(references) twice.
            raise ValueError(
                f"references size ({len(references)})"
                f" doesn't match predictions size ({len(predictions)})."
            )

        for reference in references:
            self._validate_reference(reference)

        for prediction in predictions:
            self._validate_prediction(prediction)

    def _validate_prediction(self, prediction):
        if not isoftype(prediction, self.prediction_type):
            raise ValueError(
                f"Each prediction is expected to be of type '{to_type_string(self.prediction_type)}' in {self.get_metric_name()} metric. Received prediction of type {type(prediction)}: {prediction}"
            )

    def _validate_reference(self, reference):
        # Each prediction's references are themselves a list (possibly of one
        # element when single_reference_per_prediction is set).
        if not isoftype(reference, List[Any]):
            raise ValueError(
                f"Expecting a list of references for each prediction in {self.get_metric_name()} metric. Received reference of type {type(reference)}: {reference}"
            )
        if self.single_reference_per_prediction and not len(reference) == 1:
            raise ValueError(
                f"Expecting a list with a single reference per prediction in {self.get_metric_name()} metric. Received a list with multiple references: {reference}"
            )
        for ref in reference:
            if not isoftype(ref, self.prediction_type):
                raise ValueError(
                    f"Each reference is expected to be of type '{to_type_string(self.prediction_type)}' in {self.get_metric_name()} metric. Received reference of type {type(ref)}: {ref}"
                )

    def get_metric_name(self):
        """Return the catalog id of this metric if it has one, else the class name."""
        if self.__id__ is not None:
            return self.__id__
        return self.__class__.__name__

    def consume_stream(self, stream: Stream):
        """Materialize a stream into parallel lists.

        Returns:
            A 4-tuple ``(predictions, references, additional_inputs, instances)``
            where additional_inputs defaults to ``{}`` per instance when absent.
        """
        references = []
        predictions = []
        additional_inputs = []
        instances = []
        for instance in stream:
            instance = self.verify_instance(instance)
            references.append(instance["references"])
            predictions.append(instance["prediction"])
            additional_inputs.append(
                instance["additional_inputs"] if "additional_inputs" in instance else {}
            )
            instances.append(instance)
        return predictions, references, additional_inputs, instances

    @staticmethod
    def update_instance_scores(instances, instances_scores: List[Dict[str, Any]]):
        """Merge each instance's new scores into instance["score"]["instance"]."""
        for instance, new_scores in zip(instances, instances_scores):
            if "score" not in instance:
                instance["score"] = {}
            scores = instance["score"]
            if "instance" not in scores:
                scores["instance"] = {}
            scores["instance"].update(new_scores)

    @staticmethod
    def set_global_score(instances, global_score: Dict[str, Any]):
        """Attach the shared global_score dict to every instance (by reference)."""
        for instance in instances:
            if "score" not in instance:
                instance["score"] = {}
            # Plain assignment (not update): all instances share the same
            # global_score object. (Removed a dead `if "global" not in scores`
            # branch that was immediately overwritten by this assignment.)
            instance["score"]["global"] = global_score

    @abstractmethod
    def disable_confidence_interval_calculation(self):
        pass

    def update_and_adjust_global_score(
        self, instance: Dict[str, Any], global_score: dict
    ):
        """Update instance["score"]["global"] with this metric's global_score, keeping CI fields consistent.

        global_score contains "score" and "score_name" fields reflecting (the
        main_score of) the current metric, plus "score_ci_low"/"score_ci_high"
        when CI was computed. A plain dictionary update adds new fields and
        overwrites "score"/"score_name" (and the CI fields, if present) to
        reflect the current metric.

        When the current metric did NOT compute CI but a previous metric did,
        stale "score_ci_low"/"score_ci_high" values would survive the update
        and be inconsistent with the new "score_name". In that case, the CI
        fields are popped after the update, so every "score..." field in
        instance["score"]["global"] describes the current metric. When CI IS
        computed, the update overwrites the CI fields and no fixup is needed.
        """
        for score_name in global_score:
            if score_name in [
                "score",
                "score_name",
                "score_ci_low",
                "score_ci_high",
                "num_of_instances",
            ]:
                continue
            if score_name in instance["score"]["global"]:
                UnitxtWarning(
                    message=f"Global metric '{score_name}' that has just been evaluated to {global_score[score_name]}, is already recorded "
                    f"to have value {instance['score']['global'][score_name]} by a previous metric evaluation on this stream. "
                    f"To avoid overwriting the value, add a score_prefix to the metric (e.g. score_prefix='my_{score_name}'.",
                    additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
                )
        instance["score"]["global"].update(global_score)
        # Drop stale CI fields left behind by a previous metric (see docstring).
        for score_ci in ["score_ci_low", "score_ci_high"]:
            if score_ci in global_score:
                continue
            if score_ci in instance["score"]["global"]:
                instance["score"]["global"].pop(score_ci)

0 commit comments

Comments
 (0)