
Commit f8277b1

Clean up RagasEvaluator interfaces

Refactor the RagasEvaluator class for use with the `ilab` interface.

Signed-off-by: Ali Maredia <[email protected]>

1 parent 207ecc6 commit f8277b1

File tree

1 file changed (+38, -58)


src/instructlab/eval/ragas.py

Lines changed: 38 additions & 58 deletions
```diff
@@ -74,7 +74,7 @@ class ModelConfig(BaseModel):
     max_tokens: int = 768
 
     # Random seed for reproducibility. Caution: this isn't supported by all model serving runtimes.
-    seed: int = DEFAULT_SEED
+    seed: int = 42
 
 
 class RagasEvaluator(Evaluator):
```
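For orientation, a minimal sketch of constructing the `ModelConfig` this hunk touches; the field names are taken from this diff, but every value below is hypothetical:

```python
# Not from the commit: a hypothetical ModelConfig, using fields visible in this diff.
from instructlab.eval.ragas import ModelConfig

config = ModelConfig(
    model_name="my-model",                         # hypothetical
    system_prompt="You are a helpful assistant.",  # hypothetical
    temperature=0.0,                               # read by generate_answers_from_model below
    max_tokens=768,                                # the default shown earlier in this file
    seed=1337,                                     # overrides the new default of 42
)
```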
```diff
@@ -96,29 +96,25 @@ def __init__(
         self.judge_openai_api_key = judge_openai_api_key
 
     @staticmethod
-    def _validate_dataset(df: DataFrame):
+    def validate_dataset(df: DataFrame):
         """
         Validates whether or not the given `df` is a valid dataset of `Sample` objects.
 
         Args:
-            df (DataFrame): DataFrame containing the dataset to be evaluated.
+            df (DataFrame): DataFrame containing the dataset to be evaluated.
         """
-        # We have to hardcode these fields because the automated way of resolving the required fields from a TypedDict
-        # is only included by default in Python3.11+. For earlier versions, the `typing_extensions` package is required.
-        # See: https://docs.python.org/3/whatsnew/3.11.html#pep-655-marking-individual-typeddict-items-as-required-or-not-required
-        required_keys = {"user_input", "reference"}
-        missing_keys = required_keys - set(df.columns)
-        if missing_keys:
+        required_keys = {"user_input", "reference", "response"}
+
+        columns_list = set(df.columns)
+        if not columns_list.issubset(required_keys):
             raise ValueError(
-                f"invalid dataset provided, missing the following keys: {', '.join(missing_keys)}"
+                f"Dataset can only have the following keys: {', '.join(required_keys)}. Keys provided were: {', '.join(columns_list)}"
             )
 
     def run(
         self,
-        dataset: List[Sample] | Path,
-        student_model: ModelConfig | None = None,
+        dataset: List[Sample] | DataFrame,
         run_config: RunConfig | None = None,
-        student_openai_client: OpenAIClient | None = None,
         judge_model_name: str | None = None,
         judge_openai_api_key: str | None = None,
     ) -> EvaluationResult:
```
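The renamed `validate_dataset` now checks that the columns are a subset of the allowed keys rather than computing missing keys, so extra columns are rejected while missing ones pass. A minimal sketch, assuming pandas is available and using made-up sample rows:

```python
# Sketch of the new subset-based validation; the sample data is hypothetical.
from pandas import DataFrame

from instructlab.eval.ragas import RagasEvaluator

ok = DataFrame([{"user_input": "What is RAG?", "reference": "Retrieval-augmented generation."}])
RagasEvaluator.validate_dataset(ok)  # passes: columns are a subset of the allowed keys

extra = ok.assign(notes="scratch")
RagasEvaluator.validate_dataset(extra)  # raises ValueError: "notes" is not an allowed key
```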
```diff
@@ -132,17 +128,12 @@ def run(
             dataset (List[Sample] | Path):
                 Can be either a list of `Sample` objects or a path to a jsonl file containing
                 records matching `Sample`.
-            student_model: (StudentModelConfig):
-                When this parameter is provided, we'll attempt to use the described model in order to
                 generate the responses from the given list of questions.
             run_config (RunConfig | None, optional):
                 Configuration to use when running evaluations. If none is provided, then
                 a default one is created containing extremely permissive settings when handling
                 timeouts. This is because by default, OpenAI tier-1 usage accounts have very high
                 rate limits resulting in heavy throttling during evaluations.
-            student_openai_client (openai.Client | None, optional):
-                The client to use when generating questions from the student model, must be compatible with the OpenAI API.
-                This field is required when `student_model` is provided.
             judge_model_name (str | None, optional):
                 Name of the OpenAI model to use as the judge model. Defaults to "gpt-4o" when none is specified.
             judge_openai_api_key (str | None, optional):
```
```diff
@@ -158,50 +149,29 @@ def run(
         judge_openai_api_key = (
             judge_openai_api_key if judge_openai_api_key else self.judge_openai_api_key
         )
-        student_model = student_model if student_model else self.student_model
         run_config = run_config if run_config else self.run_config
-        student_openai_client = (
-            student_openai_client
-            if student_openai_client
-            else self.student_openai_client
-        )
 
         # ensure we are in the dataframe format
-        input_df = None
+        input_df = dataset
         if isinstance(dataset, list):
             input_df = DataFrame(dataset)
-        elif isinstance(dataset, Path):
-            input_df = read_json(dataset, orient="records", lines=True)
-        else:
+        elif not isinstance(dataset, DataFrame):
             raise TypeError(f"invalid type of dataset: {type(dataset)}")
 
         # this should never happen, but pylint is not smart enough to detect it
         if TYPE_CHECKING:
             assert input_df is not None
 
         # ensure the dataset is in the format we expect it
-        self._validate_dataset(input_df)
-
-        need_to_generate_questions = "response" not in input_df.columns
-        if need_to_generate_questions:
-            logger.debug(
-                "`response` is missing in the input dataframe columns, generating questions from the model is required."
-            )
-            if not student_model or not student_openai_client:
-                raise ValueError(
-                    "provided dataset doesn't contain the model `response`, but either `student_model` or `student_openai_client` wasn't provided for inference"
-                )
-
-        # if the student model was provided then we always generate regardless
-        if student_model:
-            if not student_openai_client:
-                raise ValueError(
-                    "`student_model` was specified but `student_openai_client` was not provided"
-                )
-            input_df = self._generate_answers_from_model(
-                input_df, student_model, student_openai_client
+        # this looks similar to validate_dataset but here we want an exact match, not a subset
+        required_keys = {"user_input", "reference", "response"}
+        columns = set(input_df.columns)
+        if columns != required_keys:
+            raise ValueError(
+                f"Input Dataset can only have the following keys: {', '.join(required_keys)}. Keys provided were: {', '.join(columns)}"
             )
 
+
         if not run_config:
             # we set extreme timeout/retry values by default since OpenAI tier-1 rate limits
             # are horrible and will result in half of our evaluation results being NaN or 0
```
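Note the comment in the hunk: unlike `validate_dataset`, `run` demands an exact column match, so a DataFrame that passes validation can still be rejected here. A small sketch of the check itself, with made-up rows:

```python
# Sketch of the exact-match check run() now performs; sample rows are hypothetical.
from pandas import DataFrame

required_keys = {"user_input", "reference", "response"}

partial = DataFrame([{"user_input": "What is RAG?", "reference": "..."}])
print(set(partial.columns) == required_keys)   # False: run() would raise ValueError

complete = partial.assign(response="Retrieval-augmented generation ...")
print(set(complete.columns) == required_keys)  # True: run() would proceed
```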
```diff
@@ -229,15 +199,25 @@ def run(
         )
         return results
 
-    def _generate_answers_from_model(
-        self,
+    @staticmethod
+    def generate_answers_from_model(
         questions: DataFrame,
-        student_model: ModelConfig,
-        student_openai_client: OpenAIClient,
+        model_config: ModelConfig,
+        openai_client: OpenAIClient,
     ) -> DataFrame:
         """
         Given a DataFrame containing `user_input` columns, generates responses from the given model
         and returns a new DataFrame containing its answers in the `response` column.
+
+        Args:
+            questions: (DataFrame):
+                Questions and reference answers to be returned with the responses from the model
+            model_config: (ModelConfig):
+                Configuration settings for the model when getting responses.
+            openai_client (openai.Client | None, optional):
+                The client to use when generating questions from the model, must be compatible with the OpenAI API.
+        Returns:
+            DataFrame with user_input, reference, and response columns; the responses for each user_input come from the model
         """
         # initialize response to write into
         updated_df = questions.copy()
```
```diff
@@ -247,17 +227,17 @@ def _generate_answers_from_model(
             messages: List[ChatCompletionMessageParam] = [
                 {
                     "role": "system",
-                    "content": student_model.system_prompt,
+                    "content": model_config.system_prompt,
                 },
                 {"role": "user", "content": qna["user_input"]},
             ]
-            response = student_openai_client.chat.completions.create(
+            response = openai_client.chat.completions.create(
                 messages=messages,
-                model=student_model.model_name,
+                model=model_config.model_name,
                 # specify the seed so we can at least try to have some reproducibility when the clients support it
-                seed=42,
-                max_tokens=student_model.max_tokens,
-                temperature=student_model.temperature,
+                seed=model_config.seed,
+                max_tokens=model_config.max_tokens,
+                temperature=model_config.temperature,
             )
             updated_df.at[i, "response"] = response.choices[0].message.content
         return updated_df
```
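Taken together, the commit decouples answer generation from evaluation, which is presumably what makes the class easier to drive from `ilab`. An end-to-end sketch under stated assumptions: the endpoint, keys, and model names are placeholders, and the no-argument `RagasEvaluator()` construction is an assumption, since the constructor isn't shown in this diff:

```python
# End-to-end sketch, not from the commit. Placeholders: base_url, api keys, model names.
from openai import OpenAI
from pandas import DataFrame

from instructlab.eval.ragas import ModelConfig, RagasEvaluator

questions = DataFrame(
    [{"user_input": "What is RAG?", "reference": "Retrieval-augmented generation ..."}]
)

# Step 1: generation is now an explicit static call instead of something run() does implicitly.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")  # hypothetical local server
answered = RagasEvaluator.generate_answers_from_model(
    questions,
    model_config=ModelConfig(model_name="my-model", system_prompt="You are a helpful assistant."),
    openai_client=client,
)

# Step 2: run() takes the fully populated DataFrame directly.
evaluator = RagasEvaluator()  # assumption: defaults suffice; the constructor isn't shown here
result = evaluator.run(
    dataset=answered,
    judge_model_name="gpt-4o",      # the documented default judge
    judge_openai_api_key="sk-...",  # placeholder
)
```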
