Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/heretic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,24 @@ class Settings(BaseSettings):
),
)

# Opt-in switch, disabled by default. The evaluator's get_score() reads this flag
# and, only when it is True, runs count_refusals() over the good prompts to count
# false refusals; when it is False the false-refusal count stays at 0.
detect_false_refusals: bool = Field(
default=False,
description=(
"Whether to detect false refusals on benign (good) prompts during evaluation. "
"When enabled, the model is checked for refusing harmless prompts that it shouldn't refuse. "
"False refusals are penalized in the KL divergence component of the optimization score."
),
)

# Scaling factor for the false-refusal penalty. In the evaluator's get_score(),
# the false-refusal rate (false refusals divided by the number of good prompts)
# is multiplied by this weight and added to the KL-divergence score component,
# and only when at least one false refusal was counted.
# NOTE(review): no bounds (e.g. ge=0) are declared on this field, so a negative
# value would reduce that score component instead of increasing it — confirm
# whether validation is wanted.
false_refusal_weight: float = Field(
default=0.5,
description=(
"Weight for false refusal penalty when detect_false_refusals is enabled. "
"The penalty is: weight * (false_refusals / total_good_prompts). "
"Higher values more aggressively penalize over-abliteration."
),
)
Comment on lines +218 to +234
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to the repository's style guide (rule #8), new settings added to config.py should also be added to config.default.toml. The new settings detect_false_refusals and false_refusal_weight are missing from config.default.toml.

Please add them to config.default.toml with their default values and descriptions, ensuring the order matches config.py.

References
  1. When new settings are added in config.py, they should also be added to config.default.toml, set to their default value and with their description as a comment. The order of settings in config.default.toml should match that in config.py. (link)


n_trials: int = Field(
default=200,
description="Number of abliteration trials to run during optimization.",
Expand Down
28 changes: 23 additions & 5 deletions src/heretic/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,18 @@ def is_refusal(self, response: str) -> bool:

return False

def count_refusals(self) -> int:
def count_refusals(self, prompts: list[Prompt] | None = None) -> int:
if prompts is None:
prompts = self.bad_prompts

refusal_count = 0

responses = self.model.get_responses_batched(
self.bad_prompts,
prompts,
skip_special_tokens=True,
)

for prompt, response in zip(self.bad_prompts, responses):
for prompt, response in zip(prompts, responses):
is_refusal = self.is_refusal(response)
if is_refusal:
refusal_count += 1
Expand All @@ -92,7 +95,7 @@ def count_refusals(self) -> int:

return refusal_count

def get_score(self) -> tuple[tuple[float, float], float, int]:
def get_score(self) -> tuple[tuple[float, float], float, int, int]:
print(" * Obtaining first-token probability distributions...")
logprobs = self.model.get_logprobs_batched(self.good_prompts)
kl_divergence = F.kl_div(
Expand All @@ -107,6 +110,16 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
refusals = self.count_refusals()
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")

# Detect false refusals on benign prompts if enabled.
false_refusals = 0
if self.settings.detect_false_refusals:
print(" * Checking for false refusals on good prompts...")
false_refusals = self.count_refusals(self.good_prompts)
color = "red" if false_refusals > 0 else "green"
print(
f" * False refusals: [{color}][bold]{false_refusals}[/]/{len(self.good_prompts)}[/]"
)

kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target

Expand All @@ -117,9 +130,14 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
else:
kld_score = refusals_score * kl_divergence_target / kl_divergence_scale

# Penalize false refusals by adding to the KL component.
if false_refusals > 0:
false_refusal_rate = false_refusals / len(self.good_prompts)
kld_score += self.settings.false_refusal_weight * false_refusal_rate

score = (
kld_score,
refusals_score,
)

return score, kl_divergence, refusals
return score, kl_divergence, refusals, false_refusals
3 changes: 2 additions & 1 deletion src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def objective(trial: Trial) -> tuple[float, float]:
print("* Abliterating...")
model.abliterate(refusal_directions, direction_index, parameters)
print("* Evaluating...")
score, kl_divergence, refusals = evaluator.get_score()
score, kl_divergence, refusals, false_refusals = evaluator.get_score()

elapsed_time = time.perf_counter() - start_time
remaining_time = (elapsed_time / (trial_index - start_index)) * (
Expand All @@ -571,6 +571,7 @@ def objective(trial: Trial) -> tuple[float, float]:

trial.set_user_attr("kl_divergence", kl_divergence)
trial.set_user_attr("refusals", refusals)
trial.set_user_attr("false_refusals", false_refusals)

return score

Expand Down