import logging
import statistics
from collections.abc import Collection, Iterable
from typing import Any

from datasets import load_dataset

try:
    from aenum import extend_enum
    from lighteval.metrics.metrics import (
        MetricCategory,
        Metrics,
        MetricUseCase,
        SampleLevelMetric,
    )
    from lighteval.tasks.lighteval_task import LightevalTaskConfig
    from lighteval.tasks.requests import Doc
except ImportError as exc:
    raise ImportError(
        "To use ether0's LightEval tasks, please install the 'lighteval' extra via:"
        " `pip install ether0[lighteval]`."
    ) from exc

from ether0.data import get_problem_category
from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, ProblemPrompt
from ether0.models import make_problem_types_filter
from ether0.rewards import accuracy_reward, format_reward

logger = logging.getLogger(__name__)

ETHER0_ACCURACY_METRIC_NAME = "ether0_accuracy"
ETHER0_FORMAT_METRIC_NAME = "ether0_format"


def evaluate_ether0_accuracy(
    predictions: list[str],
    formatted_doc: Doc,
    golds: list[str] | None = None,  # noqa: ARG001
) -> float:
    """Score a single predicted completion with ether0's accuracy reward."""
    if len(predictions) != 1:
        raise NotImplementedError(
            "Only a single prediction per doc is supported,"
            f" got {predictions} for doc {formatted_doc}."
        )
    return accuracy_reward(
        completions=predictions,
        solution=[formatted_doc.specific["solution"]],
        reasoning=formatted_doc.specific["reasoning"],
        soft=formatted_doc.specific["soft"],
        test=formatted_doc.specific["test"],
    )[0]


def evaluate_ether0_format(
    predictions: list[str],
    formatted_doc: Doc,
    golds: list[str] | None = None,  # noqa: ARG001
) -> float:
    """Score a single predicted completion with ether0's format reward."""
    if len(predictions) != 1:
        raise NotImplementedError(
            "Only a single prediction per doc is supported,"
            f" got {predictions} for doc {formatted_doc}."
        )
    if formatted_doc.specific["test"]:
        logger.warning("ether0's format reward is only applicable at training time.")
    return format_reward(
        completions=predictions,
        reasoning=formatted_doc.specific["reasoning"],
    )[0]


for metric_name, metric_eval_fn in (
    (ETHER0_ACCURACY_METRIC_NAME, evaluate_ether0_accuracy),
    (ETHER0_FORMAT_METRIC_NAME, evaluate_ether0_format),
):
    if (  # Work around https://github.com/huggingface/lighteval/issues/805
        metric_name not in Metrics.__members__
    ):
        extend_enum(
            Metrics,
            metric_name,
            SampleLevelMetric(
                metric_name=metric_name,
                higher_is_better=True,
                category=MetricCategory.GENERATIVE,
                use_case=MetricUseCase.ACCURACY,
                sample_level_fn=metric_eval_fn,
                corpus_level_fn=statistics.mean,
            ),
        )


# Dataset columns copied into each Doc's `specific` dict (reward kwargs are merged in too)
KEYS_TO_STORE_IN_DOC = {"id", "solution"}


def make_ether0_task(
    name: str,
    soft: bool,
    test: bool,
    reasoning: bool,
    problem_types: str | Collection[str] | None = None,
    metric_names: Iterable[str] | None = None,
    **kwargs,
) -> LightevalTaskConfig:
    """Create a LightEval task for the ether0-benchmark dataset."""
    reward_fn_kwargs = {"soft": soft, "test": test, "reasoning": reasoning}
    if not test:
        prob_prompt = ProblemPrompt.THINK_ANSWER if reasoning else ProblemPrompt.ANSWER
        prompt_prefix: str = prob_prompt.get_prompt()
    else:
        prompt_prefix = LOOSE_XML_ANSWER_USER_PROMPT

    def row_to_doc(row: dict[str, Any], task_name: str) -> Doc:
        """Convert an ether0-benchmark dataset row to a LightEval Doc."""
        return Doc(
            query="\n\n".join((prompt_prefix, row["problem"])),
            task_name=task_name,
            choices=[""],  # Placeholder for non-QA tasks
            gold_index=0,  # Points to above placeholder
            specific={k: row[k] for k in KEYS_TO_STORE_IN_DOC} | reward_fn_kwargs,
        )

    if metric_names is None:
        metric_names = (
            (ETHER0_ACCURACY_METRIC_NAME, ETHER0_FORMAT_METRIC_NAME)
            if not test
            else (ETHER0_ACCURACY_METRIC_NAME,)
        )
    return LightevalTaskConfig(
        name=name,
        prompt_function=row_to_doc,
        suite=["community"],
        hf_repo="futurehouse/ether0-benchmark",
        hf_subset="default",
        hf_filter=(
            make_problem_types_filter(problem_types, type_col="problem_type")
            if problem_types is not None
            else None
        ),
        hf_avail_splits=["test"],
        evaluation_splits=["test"],
        metric=[getattr(Metrics, metric_name) for metric_name in metric_names],
        **kwargs,
    )


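# A minimal sketch of calling make_ether0_task directly. The task name and the
# problem-type regex below are hypothetical and only illustrate the keyword arguments;
# the tasks actually exported are built programmatically in TASKS_TABLE below.
#
#     example_task = make_ether0_task(
#         "ether0:strict:example",
#         soft=False,
#         test=False,
#         reasoning=True,
#         problem_types="re:^reaction-prediction.*$",
#     )
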
# TASKS_TABLE is required by LightEval for --custom-tasks CLI arg
TASKS_TABLE = [  # Add general tasks
    make_ether0_task(
        f"ether0:{nickname}{':soft' if is_soft else ''}",
        soft=is_soft,
        test=kwargs["test"],
        reasoning=kwargs["reasoning"],
    )
    for is_soft in (False, True)
    for nickname, kwargs in (
        ("loose", {"test": True, "reasoning": False}),
        ("strict:no_reasoning", {"test": False, "reasoning": False}),
        ("strict", {"test": False, "reasoning": True}),
    )
]
TASKS_TABLE.extend([  # Add problem type-specific tasks
    make_ether0_task(
        f"ether0:{nickname}{':soft' if is_soft else ''}:{prob_cat}",
        soft=is_soft,
        test=kwargs["test"],
        reasoning=kwargs["reasoning"],
        problem_types=f"re:^{prob_cat}.*$",
    )
    for is_soft in (False, True)
    for nickname, kwargs in (
        ("loose", {"test": True, "reasoning": False}),
        ("strict:no_reasoning", {"test": False, "reasoning": False}),
        ("strict", {"test": False, "reasoning": True}),
    )
    for prob_cat in {
        get_problem_category(pt)
        for pt in load_dataset("futurehouse/ether0-benchmark", split="test")[
            "problem_type"
        ]
    }
])
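
# Usage sketch (assumptions: a `lighteval accelerate` subcommand taking `--custom-tasks`,
# and the "suite|task|few_shot|truncation" task-string syntax; exact flags and syntax
# vary across lighteval versions, so treat this as illustrative rather than exact):
#
#     lighteval accelerate \
#         --model_args "pretrained=<your-model>" \
#         --tasks "community|ether0:strict|0|0" \
#         --custom-tasks path/to/this_module.py
#
# Task names follow the patterns built above, e.g. "ether0:loose", "ether0:strict:soft",
# or a problem category-specific variant such as "ether0:strict:<problem_category>".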