from abc import abstractmethod
from typing import (
    Any,
    Dict,
    List,
    Union,
)

from .artifact import Artifact
from .dataclass import (
    AbstractField,
)
from .deprecation_utils import deprecation
from .error_utils import Documentation, UnitxtWarning
from .stream import Stream
from .type_utils import Type, isoftype, parse_type_string, to_type_string


@deprecation(
    version="2.0.0",
    msg="use regular type instead of strings (e.g. Dict[str] instead of 'Dict[str]')",
)
def parse_string_types_instead_of_actual_objects(obj):
    return parse_type_string(obj)


class Metric(Artifact):
    main_score: str = AbstractField()
    # Override 'prediction_type' with the expected type of predictions
    # and references. Example: "List[str]", "List[Dict]", "string".
    # If left with its default value, a warning will be displayed.
    # In future versions of unitxt, this will be an error.
    prediction_type: Union[Type, str] = Any
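    # Illustrative (hypothetical) overrides a subclass might declare:
    #     prediction_type = str        # e.g. a string-matching metric
    #     prediction_type = List[str]  # e.g. a multi-label metric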

    # Standard metrics can receive multiple references per prediction (in a list).
    # Some metrics support only a single reference per prediction (one element in the list).
    single_reference_per_prediction: bool = False
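    # For example (illustrative values): with single_reference_per_prediction=True,
    # a reference list of ["Paris"] for a prediction is accepted, while
    # ["Paris", "paris"] is rejected by _validate_reference.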

    #
    # Used to add a prefix to all score names, except the "score", "score_name"
    # and "num_of_instances" fields.
    # This is used to distinguish two scores of the same metric, operating on
    # different fields of the task.
    #
    score_prefix: str = ""
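    # Illustrative example (hypothetical names): two copies of an F1 metric applied
    # to different task fields could use score_prefix="" and score_prefix="context_",
    # yielding the score names "f1" and "context_f1" side by side.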

    def prepare_args(self):
        super().prepare_args()
        if isinstance(self.prediction_type, str):
            self.prediction_type = parse_string_types_instead_of_actual_objects(
                self.prediction_type
            )

    @classmethod
    def process_data_after_load(cls, data):
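        """Parse the serialized 'prediction_type' string back into an actual type when loading the artifact."""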
| 53 | + if "prediction_type" in data: |
| 54 | + data["prediction_type"] = parse_type_string(data["prediction_type"]) |
| 55 | + return data |
| 56 | + |
| 57 | + def process_data_before_dump(self, data): |
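        """Serialize 'prediction_type' back to its string form before dumping the artifact."""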
| 58 | + if "prediction_type" in data: |
| 59 | + if not isinstance(data["prediction_type"], str): |
| 60 | + data["prediction_type"] = to_type_string(data["prediction_type"]) |
| 61 | + return data |

    def _add_score_prefix(self, score_name):
        return (
            self.score_prefix + score_name
            if score_name not in ["score", "score_name", "num_of_instances"]
            else score_name
        )

    def _add_score_prefixes_to_score_dict_and_check_against_existing_scores(
        self, scores: Dict[str, Any], existing_scores: Dict[str, Any]
    ) -> Dict[str, Any]:
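        """Return a copy of `scores` with `score_prefix` prepended to each score name
        (except "score", "score_name" and "num_of_instances") and to the value stored
        under "score_name", warning when a prefixed name would overwrite a score that
        is already present in `existing_scores`."""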
        new_scores = {}
        for score_name, score in scores.items():
            score_with_prefix = self._add_score_prefix(score_name)
            new_scores[score_with_prefix] = (
                score if score_name not in ["score_name"] else self.score_prefix + score
            )
        for new_score_name in new_scores:
            if new_score_name in ["score", "score_name", "num_of_instances"]:
                continue
            if new_score_name in existing_scores:
                UnitxtWarning(
                    message=f"Metric '{new_score_name}', which has just been evaluated to {new_scores[new_score_name]}, is already recorded "
                    f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
                    f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_', "
                    f"which will yield, in this case, a score named 'my_second_{new_score_name}').",
                    additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
                )
        return new_scores

    def _validate_references_and_prediction(self, references, predictions):
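        """Check that `predictions` and `references` are equally sized lists, that each
        reference entry is itself a list, and that every prediction and reference value
        matches the declared `prediction_type`."""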
        if not isoftype(predictions, List[Any]):
            raise ValueError(
                f"Metric {self.get_metric_name()} should receive a list of predictions. Received predictions of type {type(predictions)}: {predictions}"
            )

        if not isoftype(references, List[Any]):
            raise ValueError(
                f"Metric {self.get_metric_name()} should receive a list of references. Received references of type {type(references)}: {references}"
            )

        if len(references) != len(predictions):
            raise ValueError(
                f"references size ({len(references)})"
                f" doesn't match predictions size ({len(predictions)})."
            )

        for reference in references:
            self._validate_reference(reference)

        for prediction in predictions:
            self._validate_prediction(prediction)

    def _validate_prediction(self, prediction):
        if not isoftype(prediction, self.prediction_type):
            raise ValueError(
                f"Each prediction is expected to be of type '{to_type_string(self.prediction_type)}' in {self.get_metric_name()} metric. Received prediction of type {type(prediction)}: {prediction}"
            )

    def _validate_reference(self, reference):
        if not isoftype(reference, List[Any]):
            raise ValueError(
                f"Expecting a list of references for each prediction in {self.get_metric_name()} metric. Received reference of type {type(reference)}: {reference}"
            )
        if self.single_reference_per_prediction and len(reference) != 1:
            raise ValueError(
                f"Expecting a list with a single reference per prediction in {self.get_metric_name()} metric. Received a list with {len(reference)} references: {reference}"
            )
        for ref in reference:
            if not isoftype(ref, self.prediction_type):
                raise ValueError(
                    f"Each reference is expected to be of type '{to_type_string(self.prediction_type)}' in {self.get_metric_name()} metric. Received reference of type {type(ref)}: {ref}"
                )

    def get_metric_name(self):
        if self.__id__ is not None:
            return self.__id__
        return self.__class__.__name__

    def consume_stream(self, stream: Stream):
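        """Iterate over the stream, verify each instance, and collect its predictions,
        references, additional inputs, and the instances themselves as four parallel lists."""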
        references = []
        predictions = []
        additional_inputs = []
        instances = []
        for instance in stream:
            instance = self.verify_instance(instance)
            references.append(instance["references"])
            predictions.append(instance["prediction"])
            additional_inputs.append(instance.get("additional_inputs", {}))
            instances.append(instance)
        return predictions, references, additional_inputs, instances

    @staticmethod
    def update_instance_scores(instances, instances_scores: List[Dict[str, Any]]):
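        """Merge each instance's newly computed scores into instance["score"]["instance"],
        creating the nested dictionaries if they do not exist yet."""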
        for instance, new_scores in zip(instances, instances_scores):
            if "score" not in instance:
                instance["score"] = {}
            scores = instance["score"]
            if "instance" not in scores:
                scores["instance"] = {}
            scores["instance"].update(new_scores)

    @staticmethod
    def set_global_score(instances, global_score: Dict[str, Any]):
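        """Attach the shared `global_score` dict as instance["score"]["global"] on every instance."""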
        for instance in instances:
            if "score" not in instance:
                instance["score"] = {}
            instance["score"]["global"] = global_score

    @abstractmethod
    def disable_confidence_interval_calculation(self):
        pass

    # Update instance["score"]["global"] with the global_score just computed for the
    # current metric. global_score contains "score" and "score_name" fields that reflect
    # (the main_score of) the current metric; if CI was computed, it also contains
    # "score_ci_low" and "score_ci_high".
    # A plain dict update adds the new fields to instance["score"]["global"] and overwrites
    # "score", "score_name" (and "score_ci_low"/"score_ci_high", if present) that any
    # previous metric may have set, so that they reflect the current metric.
    # One fixup is needed: when the current metric did NOT compute CI but an earlier metric
    # did, the earlier metric's "score_ci_low" and "score_ci_high" would survive the update
    # and no longer be consistent with "score_name". In that case the two CI fields are
    # popped from instance["score"]["global"], so that all "score..." fields there describe
    # the current metric: the metric named by instance["score"]["global"]["score_name"],
    # whose score appears in instance["score"]["global"]["score"], and which has no CI
    # scores, as reflected by the absence of "score_ci_low" and "score_ci_high".
    # When CI IS computed for the current metric, the dict update already overwrites the
    # existing CI fields and no further fixup is needed.
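    # Illustrative example (hypothetical values): if instance["score"]["global"] holds
    #     {"rouge": 0.5, "score": 0.5, "score_name": "rouge",
    #      "score_ci_low": 0.4, "score_ci_high": 0.6}
    # and the current metric produced global_score = {"accuracy": 0.9, "score": 0.9,
    # "score_name": "accuracy"} (no CI), the result is
    #     {"rouge": 0.5, "accuracy": 0.9, "score": 0.9, "score_name": "accuracy"}
    # with the stale "score_ci_low"/"score_ci_high" popped out.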
    def update_and_adjust_global_score(
        self, instance: Dict[str, Any], global_score: dict
    ):
        for score_name in global_score:
            if score_name in [
                "score",
                "score_name",
                "score_ci_low",
                "score_ci_high",
                "num_of_instances",
            ]:
                continue
            if score_name in instance["score"]["global"]:
                UnitxtWarning(
                    message=f"Global metric '{score_name}', which has just been evaluated to {global_score[score_name]}, is already recorded "
                    f"to have value {instance['score']['global'][score_name]} by a previous metric evaluation on this stream. "
                    f"To avoid overwriting the value, add a score_prefix to the metric (e.g. score_prefix='my_{score_name}').",
                    additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
                )
        instance["score"]["global"].update(global_score)
        for score_ci in ["score_ci_low", "score_ci_high"]:
            if score_ci in global_score:
                continue
            if score_ci in instance["score"]["global"]:
                instance["score"]["global"].pop(score_ci)