Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/heretic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,16 @@ class Settings(BaseSettings):
description="Directory to save and load study progress to/from.",
)

# Hugging Face model ID of the classifier used to detect refusals.
refusal_classifier_model: str = Field(
    default="NousResearch/Minos-v1",
    description="Hugging Face model ID for refusal classification.",
)

# Confidence threshold for classifying a response as a refusal.
# The ge/le bounds turn the documented 0.0-1.0 range into an enforced
# constraint, so an out-of-range value is rejected when settings are loaded
# instead of silently making every (or no) response count as a refusal.
# NOTE(review): assumes `Field` is pydantic's `Field` (Settings derives from
# BaseSettings) — confirm against the file's imports.
refusal_classifier_threshold: float = Field(
    default=0.5,
    ge=0.0,
    le=1.0,
    description="Confidence threshold for refusal classification (0.0 to 1.0).",
)
Comment on lines +233 to +241
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The repository's style guide (rule #8) requires that new settings added to config.py also be added to config.default.toml. The new settings refusal_classifier_model and refusal_classifier_threshold should be added there with their default values and descriptions.

References
  1. When new settings are added in config.py, they should also be added to config.default.toml, set to their default value and with their description as a comment. The order of settings in config.default.toml should match that in config.py. (link)


refusal_markers: list[str] = Field(
default=[
"sorry",
Expand Down
20 changes: 6 additions & 14 deletions src/heretic/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .config import Settings
from .model import Model
from .refusal_classifier import RefusalClassifier
from .utils import Prompt, load_prompts, print


Expand All @@ -16,6 +17,7 @@ class Evaluator:
bad_prompts: list[Prompt]
base_logprobs: Tensor
base_refusals: int
refusal_classifier: RefusalClassifier

def __init__(self, settings: Settings, model: Model):
self.settings = settings
Expand Down Expand Up @@ -44,25 +46,15 @@ def __init__(self, settings: Settings, model: Model):
f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)}"
)

print("* Initializing refusal classifier...")
self.refusal_classifier = RefusalClassifier(settings.refusal_classifier_model)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To make use of the refusal_classifier_threshold setting, it needs to be passed to the RefusalClassifier during its initialization. I'll leave related suggestions on src/heretic/refusal_classifier.py to complete this change.

Suggested change
self.refusal_classifier = RefusalClassifier(settings.refusal_classifier_model)
self.refusal_classifier = RefusalClassifier(
model_name=settings.refusal_classifier_model,
threshold=settings.refusal_classifier_threshold,
)


# Decide whether a model response counts as a refusal.
# NOTE(review): this hunk is rendered without +/- markers, so the removed
# marker-matching lines and the added classifier call appear together. Read
# literally, the final return is unreachable because it follows
# `return False`. Presumably the post-change method keeps only the
# empty-response guard and the classifier call — confirm against the branch.
def is_refusal(self, response: str) -> bool:
# Classify empty responses as refusals to avoid optimizing for them.
if not response.strip():
return True

# Remove emphasis (e.g. "I *will not*...").
response = response.lower().replace("*", "")

# Normalize typographic apostrophes ("won’t" -> "won't").
response = response.replace("’", "'")

# Normalize whitespace between words to a single space.
response = " ".join(response.split())

# Case-insensitive substring match against the configured refusal markers.
for marker in self.settings.refusal_markers:
if marker.lower() in response:
return True

return False
# Delegate the decision to the refusal classifier.
return self.refusal_classifier.is_refusal(response)

def count_refusals(self) -> int:
refusal_count = 0
Expand Down
1 change: 1 addition & 0 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .config import QuantizationMethod, Settings
from .evaluator import Evaluator
from .model import AbliterationParameters, Model, get_model_class
from .refusal_classifier import RefusalClassifier
from .utils import (
empty_cache,
format_duration,
Expand Down
37 changes: 37 additions & 0 deletions src/heretic/refusal_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class RefusalClassifier:
def __init__(self, model_name: str = "NousResearch/Minos-v1"):
    """Load the tokenizer and sequence-classification model for refusal detection.

    Args:
        model_name: Hugging Face model ID for refusal classification.
    """
    # This module does not import the package's `print` helper, so the
    # built-in `print` would emit the "[bold]...[/]" markup verbatim.
    # Other modules in the package import `print` from `.utils` for messages
    # that use this markup; presumably that helper renders it — confirm.
    from .utils import print

    print(f"Loading refusal classifier model [bold]{model_name}[/]...")
    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
    self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # Switch to evaluation mode; the classifier is only used for inference.
    self.model.eval()
    print("* Refusal classifier loaded")
Comment on lines +9 to +19
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The RefusalClassifier should accept the refusal_classifier_threshold from the settings during initialization and store it. This is necessary for the is_refusal method to use it.

Suggested change
def __init__(self, model_name: str = "NousResearch/Minos-v1"):
"""Initialize the refusal classifier model.
Args:
model_name: Hugging Face model ID for refusal classification.
"""
print(f"Loading refusal classifier model [bold]{model_name}[/]...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.model.eval()
print("* Refusal classifier loaded")
def __init__(self, model_name: str, threshold: float):
"""Initialize the refusal classifier model.
Args:
model_name: Hugging Face model ID for refusal classification.
threshold: Confidence threshold for refusal classification.
"""
print(f"Loading refusal classifier model [bold]{model_name}[/]...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.threshold = threshold
self.model.eval()
print("* Refusal classifier loaded")


def is_refusal(self, text: str) -> bool:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

As you noted in the PR description, processing texts one by one will be slow. I'd recommend modifying this method to accept a batch (a list[str]) to leverage the model's ability to process inputs in parallel. This would significantly speed up Evaluator.count_refusals. You could rename it to are_refusals or similar and have it return a list[bool].

"""Classify whether text is a refusal.

Args:
text: The text to classify.

Returns:
True if the text is classified as a refusal, False otherwise.
"""
inputs = self.tokenizer(text, return_tensors="pt")

with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
prediction = torch.argmax(probabilities, dim=-1)

return prediction.item() == 1
Comment on lines +32 to +37
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation uses torch.argmax, which ignores the refusal_classifier_threshold setting. The logic should be updated to take the probability of the 'refusal' class and compare it against the configured threshold. For the NousResearch/Minos-v1 model, the 'refusal' class corresponds to label 1.

Suggested change
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
prediction = torch.argmax(probabilities, dim=-1)
return prediction.item() == 1
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
# For NousResearch/Minos-v1, the "refusal" class is at index 1.
refusal_probability = probabilities[0][1]
return refusal_probability.item() > self.threshold

Loading