Improve standardization of classification results.
- Add JSON schema as response format for OpenAI API
- Improve classification prompts
- Handle errors when OpenAI response is too long/short
- Refactor function classify_mnemonics_api
- Improve logging
chiffonng committed Oct 27, 2024
1 parent de84c21 commit b9c7070
Showing 4 changed files with 183 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .editorconfig
@@ -6,4 +6,4 @@ indent_size = 2
trim_trailing_whitespace = true
insert_final_newline = true
[*.py]
indent_size = 4
indent_size = 4
2 changes: 1 addition & 1 deletion .gitattributes
@@ -1,2 +1,2 @@
# Ignore Jupyter Notebooks from Github Linguist Stats
*.ipynb linguist-vendored
*.ipynb linguist-vendored
7 changes: 3 additions & 4 deletions prompts/classify_mnemonics.yaml
@@ -1,7 +1,6 @@
prompts:
system: |
You are an expert in English mnemonics classification. Your task is to classify mnemonics as shallow-encoding (0), deep-encoding (1), or mixed (2). Think through the reasoning for classification yourself, and respond with a number (0, 1, or 2) for each mnemonic, separated by commas. If unsure, return -1. Do not include any other text in your response.
user: |
You are an expert in English mnemonics. Your task is to classify each mnemonic as one of the following: shallow-encoding (0), deep-encoding (1), mixed (2), or unsure (-1). Think through the reasoning for classification yourself, and respond consistently with the response format. You have to classify every mnemonic in the prompt, no more no less. If unsure, return -1. \n
Classify the mnemonics below based on the following criteria:\n
- Shallow (0): Focus on how the word sounds, looks, or rhymes.
- Deep (1): Focus on semantics, morphology, etymology, context (inferred meaning, imagery), related words (synonyms, antonyms, words with same roots). Repeating the word or using a similar-sounding word is NOT deep-encoding.
@@ -11,8 +10,8 @@ prompts:
- olfactory: Sounds like "old factory." The old factory had a strong smell, reminding workers of its olfactory history. Classification: shallow (0), since it's based on the sound.
- vacuous: Same Latin root "vacare" (empty) as "vacuum, vacant". His expression was as empty as a vacuum, showing no signs of thought. Classification: deep (1), since it only uses etymology and related words.
- malevolent: From male 'ill' + volent 'wishing' (as in "benevolent"). These male species are so violent that they always have evil plans. Classification: mixed (2) since it uses etymology and antonyms (deep-encoding), and the sounds of "male" and "violent" (shallow-encoding)\n
Mnemonics:
user: Mnemonics are separated by a newline character. Please classify each mnemonic in the same order as they appear in the prompt.\n
model: "gpt-4o-mini"
temperature: 0.2
num_outputs: 1
batch_size: 50
250 changes: 178 additions & 72 deletions src/data_pipeline/mnemonic_processing.py
@@ -3,10 +3,13 @@
import logging
from pathlib import Path
from warnings import warn
from pydantic import BaseModel, ValidationError
from pydantic.functional_validators import AfterValidator
from typing_extensions import Annotated

import pandas as pd
from dotenv import load_dotenv
from openai import OpenAI, RateLimitError
from openai import OpenAI, RateLimitError, LengthFinishReasonError, OpenAIError
from tenacity import (
after_log,
before_log,
@@ -39,16 +42,28 @@
)
logger.handlers[0].setFormatter(formatter)

logger.addHandler(logging.StreamHandler()) # Log to console

# Initialize OpenAI client
client = OpenAI()

# Load config and prompts
with open("prompts/classify_mnemonics.yaml", "r") as f:
classification_conf = safe_load(f)
CLASSIFY_SYSTEM_PROMPT = classification_conf["prompts"]["system"]
CLASSIFY_USER_PROMPT = classification_conf["prompts"]["user"]
classification_conf = safe_load(f) # dict of config
batch_size = classification_conf["batch_size"]


def validate_classification(value: int) -> int:
"""Validate classification value to be -1, 0, 1, or 2. Otherwise, return -1."""
return value if value in {-1, 0, 1, 2} else -1


ValidClassification = Annotated[int, AfterValidator(validate_classification)]


# Mnemonic classification schema
class ClassificationSchema(BaseModel):
"""Pydantic schema for the classification of mnemonics."""

classifications: list[ValidClassification]
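
# Illustrative sketch (editor's note, not part of this commit): the
# AfterValidator coerces out-of-range values to -1 instead of raising, so
# ClassificationSchema(classifications=[0, 5, 2]).classifications == [0, -1, 2]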


def combine_key_value(path: str) -> list[str]:
@@ -85,16 +100,15 @@ def combine_key_value(path: str) -> list[str]:
return combined_col.to_list()


def create_batches(data: list[str], batch_size: int = 50) -> tuple[list[str], int]:
def create_batches(data: list[str], batch_size=batch_size) -> list[str]:
"""Build batches of text data to send to OpenAI's API.
Args:
data (list[str]): The list of data to process.
batch_size (int): The size of each batch. Defaults to 50.
batch_size (int, optional): The number of mnemonics to include in each batch. Defaults to batch_size read from the config.
Returns:
flattened_batches (list[str]): The list of batches, each item is a batch of text data
batch_size (int): The size of each batch
Raises:
ValueError: if no data is provided or if the batch size is invalid.
@@ -112,101 +126,200 @@ def create_batches(data: list[str], batch_size: int = 50) -> tuple[list[str], in
flattened_batches = ["\n".join(batch) for batch in batches]
logger.info(f"Created {len(batches)} batches of mnemonics.")

return flattened_batches, batch_size
return flattened_batches
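
# Illustrative sketch (editor's note, not part of this commit): with
# batch_size=2, create_batches(["m1", "m2", "m3"]) would return
# ["m1\nm2", "m3"] -- each batch is a single newline-joined string.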


@retry(
retry=retry_if_exception_type(RateLimitError, ValueError),
retry_error_callback=lambda x: logger.error(f"Exception during retries: {x}"),
retry=retry_if_exception_type(RateLimitError),
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, min=0, max=4), # random exponential backoff, capped at 4 seconds
before=before_log(logger, logging.WARNING),
after=after_log(logger, logging.WARNING),
)
def classify_mnemonics_api(
batches: list[str],
batch_size: int,
):
"""Classify mnemonics using OpenAI's API, GPT-4o mini, and write results to a file (to save costs).
def classify_mnemonics_api(batches: list[str]):
"""Classify mnemonics using OpenAI's API, GPT-4o mini and return the responses as JSON array of numbers. Retry up to 3 times if rate limited.
Args:
batches (list[str]): The list of batches of mnemonics to categorize.
batch_size (int): The size of each batch.
Returns:
responses (str): The string of responses from OpenAI's API, formatted as a string of numbers separated by commas.
classification_by_batch (list[ValidClassification]): The list of parsed categories.
Raises:
ValueError:
- If the output file is not in parquet or csv
- If the input (batches) is not a list or collections.abc.Iterable of strings.
- If the input (batches) is not a list of strings.
"""
if not isinstance(batches, (list, str)):
raise ValueError("Batches must be a string or a list of strings.")
raise ValueError(
f"Batches must be a string or a list of strings. Current type: {type(batches)}"
)
batches = [batches] if isinstance(batches, str) else batches

logger.info(f"Processing {len(batches)} batches...")
responses = [
client.chat.completions.create(
model=classification_conf["model"],
messages=[
{"role": "system", "content": CLASSIFY_SYSTEM_PROMPT},
{"role": "user", "content": f"{CLASSIFY_USER_PROMPT}{batch}"},
],
max_completion_tokens=batch_size * 3 + 3, # 3-4 tokens per mnemonic
temperature=classification_conf["temperature"],
n=classification_conf["num_outputs"],
logger.info(
f"Configurations: batch_size={batch_size}, model={classification_conf['model']}, temperature={classification_conf['temperature']}, num_outputs={classification_conf['num_outputs']}."
)

classification_by_batch = []
for i, batch in tqdm(enumerate(batches), desc="Processing batches", unit="batch"):
classification_msg = get_structured_response(
i,
batch,
model_config=classification_conf,
response_format=ClassificationSchema,
)
.choices[0]
.message.content
for batch in tqdm(batches, desc="Processing batches", unit="batch")
]
return ",".join(responses)
classification_batch_i = parse_structured_response(classification_msg, batch, i)
classification_by_batch.extend(classification_batch_i)

logger.info(f"Returned {len(classification_by_batch)} classifications.")
return classification_by_batch
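
# Illustrative sketch (editor's note, not part of this commit): the return value
# is flattened across batches, one label per mnemonic, e.g.
# classify_mnemonics_api(["m1\nm2", "m3"]) -> [1, 0, 2]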


def parse_save_classification_results(
res_str: str, output_path: str | Path
) -> list[int]:
"""Parse comma-separated categories and save them to a file.
def get_structured_response(
i: int,
batch: str,
model_config: dict,
response_format: BaseModel = ClassificationSchema,
):
"""Get response from OpenAI API. Documentation: https://platform.openai.com/docs/guides/structured-outputs/how-to-use.
Args:
res_str (str): The string of numbers (which are the categories) separated by commas.
output_path (str | Path): The path to .csv or .parquet file to write the parsed.
i (int): The index of the batch.
batch (str): The batch of mnemonics to classify.
model_config (dict): The model configuration.
response_format (BaseModel, optional): The response format. Defaults to ClassificationSchema.
Returns:
structure_msg (message object from OpenAI's Response object): A structured message object.
"""
try:
structure_msg = (
client.beta.chat.completions.parse(
model=model_config["model"],
messages=[
{"role": "system", "content": model_config["prompts"]["system"]},
{
"role": "user",
"content": f"{model_config["prompts"]["user"]}{batch}",
},
],
max_tokens=batch_size * 3 + 1, # 3 tokens per mnemonic
temperature=model_config["temperature"],
n=model_config["num_outputs"],
response_format=response_format,
)
.choices[0]
.message
)
if structure_msg.refusal:
logger.error(f"Batch {i+1}: OpenAI refused to process the request.")
raise OpenAIError("OpenAI refused to process the request.")

return structure_msg

except Exception as e:
if isinstance(e, LengthFinishReasonError):
logger.error(f"LengthFinishReasonError: {e}")
raise ValueError(
"OpenAI run out of tokens. Please try: reducing the batch_size, or increasing the max_tokens parameter."
)
else:
logger.error(f"Exception: {e}")
raise e
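
# Illustrative sketch (editor's note, not part of this commit): on success the
# returned message carries the validated schema instance, e.g.
# msg = get_structured_response(0, "m1\nm2", model_config=classification_conf)
# msg.parsed.classifications # e.g. [1, 0]; msg.refusal is None unless the model declined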


def parse_structured_response(
structure_msg: object, batch: str, batch_index: int
) -> list[int | str]:
"""Parse the structured message from OpenAI's API.
Args:
structure_msg (message object from OpenAI's Response object): A structured message object.
batch (str): The batch of mnemonics.
batch_index (int): The index of the batch.
Returns:
scores (list[int]): The list of parsed numbers.
classification_batch_i (list[int|str]): The list of parsed categories.
"""
try:
if structure_msg.parsed:
classification_batch_i = structure_msg.parsed.classifications
batch_i_size = len(batch.split("\n"))
classification_i_size = len(classification_batch_i)

# Log batch debug info
logger.debug(
f"Batch {batch_index+1} with {batch_i_size} mnemonics: {classification_i_size} classifications."
)
logger.debug(
f"Batch {batch_index+1} classifications: {classification_batch_i}"
)
logger.debug(
f"Batch {batch_index+1} types: {type(classification_batch_i[0])}"
)

# Handle when the number of classifications does not match the number of mnemonics
if classification_i_size > batch_i_size:
logger.warning(
f"Batch {batch_index+1}: Number of classifications {classification_i_size} exceeds the number of mnemonics {batch_i_size}. Truncating to match the number of mnemonics..."
)
return classification_batch_i[:batch_i_size]

elif classification_i_size < batch_i_size:
logger.warning(
f"Batch {batch_index+1}: Number of classifications {classification_i_size} is less than the number of mnemonics {batch_i_size}. Padding with -1..."
)
return classification_batch_i + [-1] * (
batch_i_size - classification_i_size
)

else: # classification_i_size == batch_i_size
return classification_batch_i

except ValidationError as e:
logger.error(f"ValidationError: {e}")
raise ValueError(
f"Batch {batch_index+1}: The response didn't match the expected format. Check the logs for more details."
)
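
# Illustrative sketch (editor's note, not part of this commit): for a batch
# "m1\nm2\nm3", a parsed result [1, 0] is padded to [1, 0, -1] and
# [1, 0, 2, 2] is truncated to [1, 0, 2]; an exact-length result is returned unchanged.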


def save_structured_outputs(
outputs: list[ValidClassification], input_path: str | Path, output_path: str | Path
):
"""Save the classification results to an existing file of mnemonics.
Args:
outputs (list[ValidClassification]): The list of parsed categories.
input_path (str | Path): The path to the file containing the mnemonics.
output_path (str | Path): The path to .csv or .parquet file to write the parsed.
Raises:
ValueError: If the output file is not in parquet or csv format.
"""
logger.info(
f"Received {len(res_str)} characters from OpenAI's API. Preview: {res_str[:100]}"
)
categories = [int(c) for c in res_str.split(",") if c.strip().isdigit()]
if not all(c in {-1, 0, 1, 2} for c in categories):
raise ValueError("Parsed categories must be -1, 0, 1, or 2.")

# Set up output path
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)

# Read initial dataset to get the number of rows
f = which_file_exists(COMBINED_DATASET_CSV, COMBINED_DATASET_PARQUET)
df = pd.read_csv(f) if f.suffix == CSV_EXT else pd.read_parquet(f)
if len(df) != len(categories):
logger.error(
f"Number of rows in the file does not match the number of categories. Number of rows: {len(df)}, number of categories: {len(categories)}"
)
raise ValueError(
f"Number of rows in the file does not match the number of categories. Number of rows: {len(df)}, number of categories: {len(categories)}"
)
input_path = check_file_path(
input_path, new_ok=True, extensions=[PARQUET_EXT, CSV_EXT]
)
df = (
pd.read_csv(input_path)
if input_path.suffix == CSV_EXT
else pd.read_parquet(input_path)
)
if len(df) != len(outputs):
error_msg = f"Number of rows in the file does not match the number of categories. Number of rows: {len(df)}, number of categories: {len(outputs)}"
logger.error(error_msg)
raise ValueError(error_msg)

# Add the categories column and save to the requested format
df["category"] = categories
df["category"] = outputs
save_func = df.to_parquet if output_path.suffix == PARQUET_EXT else df.to_csv
save_func(output_path, index=False)
logger.info(f"Saved classification results to {str(output_path)}.")
return categories
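
# Illustrative usage (editor's note, not part of this commit; paths are
# hypothetical): save_structured_outputs([1, 0, 2], "data/combined.csv",
# "data/classified.parquet") adds a "category" column and writes it in the
# format implied by the output suffix.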


def standardize_mnemonics_api(batches):
Expand All @@ -219,27 +332,20 @@ def diversify_mnemonics_api(batches):
raise NotImplementedError


def classify_mnemonics(
input_path: str, output_path: str, batch_size: int = 50, n: int = 1
) -> list[int]:
def classify_mnemonics(input_path: str, output_path: str):
"""End-to-end function for classifying mnemonics.
Args:
input_path (str): The path to the file containing the mnemonics.
output_path (str): The path to the file to save the classification results.
batch_size (int): The size of each batch. Defaults to 50.
n (int): The number of completions to generate. Defaults to 1.
Returns:
(list[int]): The list of parsed categories.
Raises:
ValueError: If the output file is not in parquet or csv format.
"""
data = combine_key_value(input_path)
batches, batch_size = create_batches(data, batch_size)
raw_response = classify_mnemonics_api(batches, batch_size)
return parse_save_classification_results(raw_response, output_path)
batches = create_batches(data)
classifications = classify_mnemonics_api(batches)
save_structured_outputs(classifications, input_path, output_path)


classify_mnemonics(COMBINED_DATASET_CSV, CLASSIFIED_DATASET_CSV)
