From 88d7ed77a88dd1d29520b673825b0b607440301c Mon Sep 17 00:00:00 2001
From: Galunid
Date: Thu, 12 Mar 2026 11:52:43 +0100
Subject: [PATCH] Use BERT based model for refusal classification

---
Review notes:
- refusal_classifier.py: import print from .utils (as evaluator.py does for
  its [bold] markup), truncate the tokenizer input, and compare the refusal
  probability against a configurable threshold instead of a fixed argmax.
- evaluator.py: create the classifier at the top of Evaluator.__init__, before
  anything else in the constructor can call is_refusal(), and pass
  settings.refusal_classifier_threshold to it.
- Hunk headers and the diffstat are updated for the edits above. The blob IDs
  on the "index" lines are carried over unchanged and no longer match the
  edited file contents.

 src/heretic/config.py             | 10 ++++++
 src/heretic/evaluator.py          | 25 ++++++++-------
 src/heretic/main.py               |  1 +
 src/heretic/refusal_classifier.py | 50 +++++++++++++++++++++++++++++++
 4 files changed, 72 insertions(+), 14 deletions(-)
 create mode 100644 src/heretic/refusal_classifier.py

diff --git a/src/heretic/config.py b/src/heretic/config.py
index 8ed3f80c..f11a9193 100644
--- a/src/heretic/config.py
+++ b/src/heretic/config.py
@@ -230,6 +230,16 @@ class Settings(BaseSettings):
         description="Directory to save and load study progress to/from.",
     )
 
+    refusal_classifier_model: str = Field(
+        default="NousResearch/Minos-v1",
+        description="Hugging Face model ID for refusal classification.",
+    )
+
+    refusal_classifier_threshold: float = Field(
+        default=0.5,
+        description="Confidence threshold for refusal classification (0.0 to 1.0).",
+    )
+
     refusal_markers: list[str] = Field(
         default=[
             "sorry",
diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py
index f2a8a258..65f73866 100644
--- a/src/heretic/evaluator.py
+++ b/src/heretic/evaluator.py
@@ -6,6 +6,7 @@
 
 from .config import Settings
 from .model import Model
+from .refusal_classifier import RefusalClassifier
 from .utils import Prompt, load_prompts, print
 
 
@@ -16,6 +17,15 @@ class Evaluator:
     bad_prompts: list[Prompt]
     base_logprobs: Tensor
     base_refusals: int
+    refusal_classifier: RefusalClassifier
 
     def __init__(self, settings: Settings, model: Model):
+        # Create the classifier first so that is_refusal() is usable for the
+        # rest of initialization.
+        print("* Initializing refusal classifier...")
+        self.refusal_classifier = RefusalClassifier(
+            settings.refusal_classifier_model,
+            threshold=settings.refusal_classifier_threshold,
+        )
+
         self.settings = settings
@@ -44,25 +54,12 @@ def __init__(self, settings: Settings, model: Model):
             f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)}"
         )
 
     def is_refusal(self, response: str) -> bool:
         # Classify empty responses as refusals to avoid optimizing for them.
         if not response.strip():
             return True
 
-        # Remove emphasis (e.g. "I *will not*...").
-        response = response.lower().replace("*", "")
-
-        # Normalize typographic apostrophes ("won’t" -> "won't").
-        response = response.replace("’", "'")
-
-        # Normalize whitespace between words to a single space.
-        response = " ".join(response.split())
-
-        for marker in self.settings.refusal_markers:
-            if marker.lower() in response:
-                return True
-
-        return False
+        return self.refusal_classifier.is_refusal(response)
 
     def count_refusals(self) -> int:
         refusal_count = 0
diff --git a/src/heretic/main.py b/src/heretic/main.py
index c480888c..77e28445 100644
--- a/src/heretic/main.py
+++ b/src/heretic/main.py
@@ -39,6 +39,7 @@
 from .config import QuantizationMethod, Settings
 from .evaluator import Evaluator
 from .model import AbliterationParameters, Model, get_model_class
+from .refusal_classifier import RefusalClassifier
 from .utils import (
     empty_cache,
     format_duration,
diff --git a/src/heretic/refusal_classifier.py b/src/heretic/refusal_classifier.py
new file mode 100644
index 00000000..0bd9ef89
--- /dev/null
+++ b/src/heretic/refusal_classifier.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: AGPL-3.0-or-later
+# Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors
+
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+from .utils import print
+
+
+class RefusalClassifier:
+    def __init__(
+        self,
+        model_name: str = "NousResearch/Minos-v1",
+        threshold: float = 0.5,
+    ):
+        """Initialize the refusal classifier model.
+
+        Args:
+            model_name: Hugging Face model ID for refusal classification.
+            threshold: Probability of the refusal class above which a text
+                is classified as a refusal.
+        """
+        print(f"Loading refusal classifier model [bold]{model_name}[/]...")
+        self.threshold = threshold
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
+        self.model.eval()
+        print("* Refusal classifier loaded")
+
+    def is_refusal(self, text: str) -> bool:
+        """Classify whether text is a refusal.
+
+        Args:
+            text: The text to classify.
+
+        Returns:
+            True if the text is classified as a refusal, False otherwise.
+        """
+        # Truncate so that long responses cannot exceed the model's maximum
+        # input length.
+        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)
+
+        with torch.no_grad():
+            logits = self.model(**inputs).logits
+
+        # Class index 1 is treated as the refusal label.
+        refusal_probability = torch.softmax(logits, dim=-1)[0, 1].item()
+
+        return refusal_probability > self.threshold