LazerLambda/Promptzl

Pr🥨mptzl

Turn state-of-the-art LLMs into zero+-shot PyTorch classifiers in just a few lines of code.

Promptzl offers:

  • 🤖 Zero+-shot classification with LLMs
  • 🤗 Turning causal and masked LMs into classifiers without any training
  • 📦 Batch processing on your device for efficiency
  • 🚀 Speed-up over calling an online API
  • 🔎 Transparency and accessibility by using the model locally
  • 📈 Distribution over labels
  • ✂️ No need to extract predictions from a generated answer

For more information, check out the official documentation.

Installation

pip install -U promptzl

Getting Started

In just a few lines of code, you can transform an LLM of your choice into an old-school classifier with all its desirable properties:

Set up the dataset:

from datasets import Dataset

dataset = Dataset.from_dict(
    {
        'text': [
            "The food was absolutely wonderful, from preparation to presentation, very pleasing.",
            "The service was a bit slow, but the food made up for it. Highly recommend the pasta!",
            "The restaurant was too noisy and the food was mediocre at best. Not worth the price.",
        ],
        'label': [1, 1, 0]
    }
)

Define a prompt for guiding the language model to the correct predictions:

from promptzl import FnVbzPair, Vbz
prompt = FnVbzPair(
    lambda e: f"""Restaurant review classification into categories 'positive' or 'negative'.

    'Best pretzls in town!'='positive'
    'Rude staff, horrible food.'='negative'

    '{e['text']}'=""",
    Vbz({0: ["negative"], 1: ["positive"]}))
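
To see what the language model actually receives, it can help to render the prompt function for a single example. The sketch below repeats the same f-string as a plain Python function, with no Promptzl involved. The name `render_prompt` and the sample review are hypothetical and only used for illustration.

```python
# Hypothetical helper for illustration; it is not part of Promptzl.
# It reproduces the prompt text from the FnVbzPair example as a plain function.
def render_prompt(example: dict) -> str:
    return f"""Restaurant review classification into categories 'positive' or 'negative'.

    'Best pretzls in town!'='positive'
    'Rude staff, horrible food.'='negative'

    '{example['text']}'="""


# A made-up review, used only to show the rendered text.
example = {"text": "The food was absolutely wonderful."}
rendered = render_prompt(example)
print(rendered)
```

The rendered text ends right after the `=` sign, so the most natural continuation for the model is one of the label words listed in `Vbz`.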

Initialize a model:

from promptzl import CausalLM4Classification
model = CausalLM4Classification(
    'HuggingFaceTB/SmolLM2-1.7B',
    prompt=prompt)

Classify the data:

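from sklearn.metrics import accuracy_score

# Run the classifier over the dataset; predictions are label ids (0 or 1).
output = model.classify(dataset, show_progress_bar=True, batch_size=1)

# Compare the predictions against the gold labels.
accuracy_score(dataset['label'], output.predictions)
# 1.0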
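
Conceptually, a distribution over labels can be obtained by looking only at the scores the model assigns to the label words and normalizing them. The following is a minimal sketch of that idea in plain Python. It is not Promptzl's implementation, the logit values are made up, and `label_distribution` is a hypothetical helper; the official documentation describes what the library actually computes.

```python
import math

# Made-up next-token logits for the label words (illustrative values only).
label_word_logits = {"negative": 1.2, "positive": 3.4}

# Label id -> label words, mirroring the mapping passed to Vbz in the prompt example.
verbalizer = {0: ["negative"], 1: ["positive"]}


def label_distribution(logits: dict, verbalizer: dict) -> dict:
    """Softmax restricted to the label words, summed per label (hypothetical helper)."""
    # Subtract the maximum logit for numerical stability.
    m = max(logits[w] for words in verbalizer.values() for w in words)
    scores = {
        label: sum(math.exp(logits[w] - m) for w in words)
        for label, words in verbalizer.items()
    }
    total = sum(scores.values())
    return {label: score / total for label, score in scores.items()}


dist = label_distribution(label_word_logits, verbalizer)
prediction = max(dist, key=dist.get)  # label id with the highest probability
print(dist, prediction)
```

Because the prediction is read directly from the scores of the label words, no free-text answer has to be generated and parsed.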

For more detailed tutorials, check out the documentation!