-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Use BERT-based model for refusal classification #224
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |||||||||||
|
|
||||||||||||
| from .config import Settings | ||||||||||||
| from .model import Model | ||||||||||||
| from .refusal_classifier import RefusalClassifier | ||||||||||||
| from .utils import Prompt, load_prompts, print | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
|
|
@@ -16,6 +17,7 @@ class Evaluator: | |||||||||||
| bad_prompts: list[Prompt] | ||||||||||||
| base_logprobs: Tensor | ||||||||||||
| base_refusals: int | ||||||||||||
| refusal_classifier: RefusalClassifier | ||||||||||||
|
|
||||||||||||
| def __init__(self, settings: Settings, model: Model): | ||||||||||||
| self.settings = settings | ||||||||||||
|
|
@@ -44,25 +46,15 @@ def __init__(self, settings: Settings, model: Model): | |||||||||||
| f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)}" | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| print("* Initializing refusal classifier...") | ||||||||||||
| self.refusal_classifier = RefusalClassifier(settings.refusal_classifier_model) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To make use of the
Suggested change
|
||||||||||||
|
|
||||||||||||
| def is_refusal(self, response: str) -> bool: | ||||||||||||
| # Classify empty responses as refusals to avoid optimizing for them. | ||||||||||||
| if not response.strip(): | ||||||||||||
| return True | ||||||||||||
|
|
||||||||||||
| # Remove emphasis (e.g. "I *will not*..."). | ||||||||||||
| response = response.lower().replace("*", "") | ||||||||||||
|
|
||||||||||||
| # Normalize typographic apostrophes ("won’t" -> "won't"). | ||||||||||||
| response = response.replace("’", "'") | ||||||||||||
|
|
||||||||||||
| # Normalize whitespace between words to a single space. | ||||||||||||
| response = " ".join(response.split()) | ||||||||||||
|
|
||||||||||||
| for marker in self.settings.refusal_markers: | ||||||||||||
| if marker.lower() in response: | ||||||||||||
| return True | ||||||||||||
|
|
||||||||||||
| return False | ||||||||||||
| return self.refusal_classifier.is_refusal(response) | ||||||||||||
|
|
||||||||||||
| def count_refusals(self) -> int: | ||||||||||||
| refusal_count = 0 | ||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,37 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: AGPL-3.0-or-later | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| class RefusalClassifier: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, model_name: str = "NousResearch/Minos-v1"): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Initialize the refusal classifier model. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| model_name: Hugging Face model ID for refusal classification. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"Loading refusal classifier model [bold]{model_name}[/]...") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = AutoModelForSequenceClassification.from_pretrained(model_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model.eval() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| print("* Refusal classifier loaded") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+9
to
+19
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def is_refusal(self, text: str) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As you noted in the PR description, processing texts one by one will be slow. I'd recommend modifying this method to accept a batch (a |
||||||||||||||||||||||||||||||||||||||||||||||||||
| """Classify whether text is a refusal. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| text: The text to classify. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| True if the text is classified as a refusal, False otherwise. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| inputs = self.tokenizer(text, return_tensors="pt") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = self.model(**inputs) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| prediction = torch.argmax(probabilities, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| return prediction.item() == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+32
to
+37
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation uses
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The repository's style guide (rule #8) requires that new settings added to
config.pyalso be added toconfig.default.toml. The new settingsrefusal_classifier_modelandrefusal_classifier_thresholdshould be added there with their default values and descriptions.References
config.py, they should also be added toconfig.default.toml, set to their default value and with their description as a comment. The order of settings inconfig.default.tomlshould match that inconfig.py. (link)