Commit 48cd7b1

Created lighteval tasks module to shim with LightEval, with tests
1 parent 9a0bb89 commit 48cd7b1

File tree: 2 files changed (+262 lines, -0 lines)

src/ether0/lighteval_tasks.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
import logging
import statistics
from collections.abc import Collection, Iterable
from typing import Any

from datasets import load_dataset

try:
    from aenum import extend_enum
    from lighteval.metrics.metrics import (
        MetricCategory,
        Metrics,
        MetricUseCase,
        SampleLevelMetric,
    )
    from lighteval.tasks.lighteval_task import LightevalTaskConfig
    from lighteval.tasks.requests import Doc
except ImportError as exc:
    raise ImportError(
        "To use ether0's LightEval tasks, please install the 'lighteval' extra via:"
        " `pip install ether0[lighteval]`."
    ) from exc

from ether0.data import get_problem_category
from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, ProblemPrompt
from ether0.models import make_problem_types_filter
from ether0.rewards import accuracy_reward, format_reward

logger = logging.getLogger(__name__)

ETHER0_ACCURACY_METRIC_NAME = "ether0_accuracy"
ETHER0_FORMAT_METRIC_NAME = "ether0_format"


def evaluate_ether0_accuracy(
    predictions: list[str],
    formatted_doc: Doc,
    golds: list[str] | None = None,  # noqa: ARG001
) -> float:
    if len(predictions) != 1:
        raise NotImplementedError(
            "Didn't handle anything besides one prediction"
            f" for doc {formatted_doc}, got {predictions}."
        )
    return accuracy_reward(
        completions=predictions,
        solution=[formatted_doc.specific["solution"]],
        reasoning=formatted_doc.specific["reasoning"],
        soft=formatted_doc.specific["soft"],
        test=formatted_doc.specific["test"],
    )[0]


def evaluate_ether0_format(
    predictions: list[str],
    formatted_doc: Doc,
    golds: list[str] | None = None,  # noqa: ARG001
) -> float:
    if len(predictions) != 1:
        raise NotImplementedError(
            "Didn't handle anything besides one prediction"
            f" for doc {formatted_doc}, got {predictions}."
        )
    if formatted_doc.specific["test"]:
        logger.warning("ether0's format reward is only applicable at training time.")
    return format_reward(
        completions=predictions,
        reasoning=formatted_doc.specific["reasoning"],
    )[0]


for metric_name, metric_eval_fn in (
    (ETHER0_ACCURACY_METRIC_NAME, evaluate_ether0_accuracy),
    (ETHER0_FORMAT_METRIC_NAME, evaluate_ether0_format),
):
    if (  # Work around https://github.com/huggingface/lighteval/issues/805
        metric_name not in Metrics.__members__
    ):
        extend_enum(
            Metrics,
            metric_name,
            SampleLevelMetric(
                metric_name=metric_name,
                higher_is_better=True,
                category=MetricCategory.GENERATIVE,
                use_case=MetricUseCase.ACCURACY,
                sample_level_fn=metric_eval_fn,
                corpus_level_fn=statistics.mean,
            ),
        )


KEYS_TO_STORE_IN_DOC = {"id", "solution"}


def make_ether0_task(
    name: str,
    soft: bool,
    test: bool,
    reasoning: bool,
    problem_types: str | Collection[str] | None = None,
    metric_names: Iterable[str] | None = None,
    **kwargs,
) -> LightevalTaskConfig:
    """Create LightEval task for the ether0-benchmark dataset."""
    reward_fn_kwargs = {"soft": soft, "test": test, "reasoning": reasoning}
    if not test:
        prob_prompt = ProblemPrompt.THINK_ANSWER if reasoning else ProblemPrompt.ANSWER
        prompt_prefix: str = prob_prompt.get_prompt()
    else:
        prompt_prefix = LOOSE_XML_ANSWER_USER_PROMPT

    def row_to_doc(row: dict[str, Any], task_name: str) -> Doc:
        """Convert an ether0-benchmark dataset row to a LightEval Doc."""
        return Doc(
            query="\n\n".join((prompt_prefix, row["problem"])),
            task_name=task_name,
            choices=[""],  # Placeholder for non-QA tasks
            gold_index=0,  # Points to above placeholder
            specific={k: row[k] for k in KEYS_TO_STORE_IN_DOC} | reward_fn_kwargs,
        )

    if metric_names is None:
        metric_names = (
            (ETHER0_ACCURACY_METRIC_NAME, ETHER0_FORMAT_METRIC_NAME)
            if not test
            else (ETHER0_ACCURACY_METRIC_NAME,)
        )
    return LightevalTaskConfig(
        name=name,
        prompt_function=row_to_doc,
        suite=["community"],
        hf_repo="futurehouse/ether0-benchmark",
        hf_subset="default",
        hf_filter=(
            make_problem_types_filter(problem_types, type_col="problem_type")
            if problem_types is not None
            else None
        ),
        hf_avail_splits=["test"],
        evaluation_splits=["test"],
        metric=[getattr(Metrics, metric_name) for metric_name in metric_names],
        **kwargs,
    )


# TASKS_TABLE is required by LightEval for --custom-tasks CLI arg
TASKS_TABLE = [  # Add general tasks
    make_ether0_task(
        f"ether0:{nickname}{':soft' if is_soft else ''}",
        soft=is_soft,
        test=kwargs["test"],
        reasoning=kwargs["reasoning"],
    )
    for is_soft in (False, True)
    for nickname, kwargs in (
        ("loose", {"test": True, "reasoning": False}),
        ("strict:no_reasoning", {"test": False, "reasoning": False}),
        ("strict", {"test": False, "reasoning": True}),
    )
]
TASKS_TABLE.extend([  # Add problem type-specific tasks
    make_ether0_task(
        f"ether0:{nickname}{':soft' if is_soft else ''}:{prob_cat}",
        soft=is_soft,
        test=kwargs["test"],
        reasoning=kwargs["reasoning"],
        problem_types=f"re:^{prob_cat}.*$",
    )
    for is_soft in (False, True)
    for nickname, kwargs in (
        ("loose", {"test": True, "reasoning": False}),
        ("strict:no_reasoning", {"test": False, "reasoning": False}),
        ("strict", {"test": False, "reasoning": True}),
    )
    for prob_cat in {
        get_problem_category(pt)
        for pt in load_dataset("futurehouse/ether0-benchmark", split="test")[
            "problem_type"
        ]
    }
])
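
For orientation (not part of the committed file), the general-task comprehension above yields six task names. A minimal sketch of that expansion, assuming LightEval prefixes each name with its suite, as the "community|ether0:..." task_name in the test below suggests:

# Illustration only (not in the commit): enumerate the six general task names
# produced by TASKS_TABLE above. The nicknames and ":soft" suffix mirror the
# f-string in the comprehension; the "community|" prefix reflects suite=["community"].
for is_soft in (False, True):
    for nickname in ("loose", "strict:no_reasoning", "strict"):
        print(f"community|ether0:{nickname}{':soft' if is_soft else ''}")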

tests/test_lighteval_tasks.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
from unittest.mock import patch

from lighteval.main_tasks import list as lighteval_list
from lighteval.metrics.metrics import Metrics, SampleLevelMetric
from lighteval.tasks.requests import Doc

import ether0.lighteval_tasks


def test_task_list(capsys) -> None:
    """Integration test designed to test TASKS_TABLE and custom task creation."""
    with patch(  # Work around https://github.com/huggingface/lighteval/issues/805
        "lighteval.tasks.registry.create_custom_tasks_module",
        side_effect=[ether0.lighteval_tasks],
    ):
        lighteval_list(custom_tasks=ether0.lighteval_tasks.__file__)
    captured = capsys.readouterr()
    assert not captured.err
    tasks = [row for row in captured.out.splitlines() if "ether0" in row]
    assert len(tasks) > 1, "Expected some ether0 tasks"
    assert any(
        "functional-group" in row for row in tasks
    ), "Expected specific tasks to be listed"
    # TODO: after https://github.com/huggingface/lighteval/issues/806,
    # remove the .litellm_cache directory created by this test


def test_accuracy_metric() -> None:
    accuracy_metric = getattr(
        Metrics, ether0.lighteval_tasks.ETHER0_ACCURACY_METRIC_NAME
    ).value
    assert isinstance(accuracy_metric, SampleLevelMetric)

    # NOTE: these inputs were taken from a gpt-4o baseline run
    doc_json = {
        "query": (
            "When answering, be sure to place the final answer as SMILES notation into"
            " XML tags <answer></answer>. An example is <answer>CCO</answer>.\n\nWhat"
            " is a valid completion of this molecule:\nO=C(OCC1=CC=CC=C1)N1CCCC1C(=O"
        ),
        "choices": [""],
        "gold_index": 0,
        "original_query": "",
        "specific": {
            "solution": (
                "valid_mol_eval!:!O=C(OCC1=CC=CC=C1)N1CCCC1C(=O!:!molecule-completion"
            ),
            "id": "e8b8bb34-731a-46e1-93a2-b6330a705148",
            "soft": False,
            "test": True,
            "reasoning": False,
        },
        "task_name": "community|ether0:loose:molecule-completion",
        "instruction": "",
        "ctx": [{
            "role": "user",
            "content": (
                "When answering, be sure to place the final answer as SMILES notation"
                " into XML tags <answer></answer>. An example is"
                " <answer>CCO</answer>.\n\nWhat is a valid completion of this"
                " molecule:\nO=C(OCC1=CC=CC=C1)N1CCCC1C(=O"
            ),
        }],
        "num_asked_few_shots": 0,
        "num_effective_few_shots": 0,
    }
    assert (
        accuracy_metric.sample_level_fn(
            predictions=[
                "The given fragment of the molecule O=C(OCC1=CC=CC=C1)N1CCCC1C(=O suggests"
                " a structure that indicates an amide linkage with a substituted"
                " cyclohexanone. A plausible completion of this structure is a standard"
                " cyclohexanone amide. Therefore, a valid SMILES notation for the completed"
                " structure is:\n\n<answer>O=C(OCC1=CC=CC=C1)N1CCCC1C(=O)C2CCCCC2</answer>"
            ],
            formatted_doc=Doc(**doc_json),
            golds=[""],
        )
        == 1.0
    )
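
As a usage sketch (illustrative only, mirroring test_task_list above rather than documenting a stable API), the shim module can be handed to LightEval's task-listing entry point, using the same workaround for lighteval issue #805 that the code references:

# Illustration only: list the registered ether0 tasks via LightEval's Python API,
# following the same pattern (and issue-805 workaround) as test_task_list above.
from unittest.mock import patch

from lighteval.main_tasks import list as lighteval_list

import ether0.lighteval_tasks

with patch(
    "lighteval.tasks.registry.create_custom_tasks_module",
    side_effect=[ether0.lighteval_tasks],
):
    lighteval_list(custom_tasks=ether0.lighteval_tasks.__file__)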
