From 8954f8cd2e248ce59d6421fe47253c9c525fab29 Mon Sep 17 00:00:00 2001 From: KewkLW Date: Tue, 3 Mar 2026 22:38:58 -0800 Subject: [PATCH 1/2] feat: detect false refusals on benign prompts during optimization Add --detect-false-refusals flag that checks whether abliteration causes the model to refuse harmless prompts it shouldn't refuse. False refusals are penalized in the KL divergence score component to steer the optimizer away from over-abliteration. The penalty is: false_refusal_weight * (false_refusals / good_prompts). Default weight is 0.5, configurable via --false-refusal-weight. Disabled by default, preserving existing behavior. --- src/heretic/config.py | 18 ++++++++++++++++++ src/heretic/evaluator.py | 28 +++++++++++++++++++++++----- src/heretic/main.py | 3 ++- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/heretic/config.py b/src/heretic/config.py index 8ed3f80c..5137c3bb 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -215,6 +215,24 @@ class Settings(BaseSettings): ), ) + detect_false_refusals: bool = Field( + default=False, + description=( + "Whether to detect false refusals on benign (good) prompts during evaluation. " + "When enabled, the model is checked for refusing harmless prompts that it shouldn't refuse. " + "False refusals are penalized in the KL divergence component of the optimization score." + ), + ) + + false_refusal_weight: float = Field( + default=0.5, + description=( + "Weight for false refusal penalty when detect_false_refusals is enabled. " + "The penalty is: weight * (false_refusals / total_good_prompts). " + "Higher values more aggressively penalize over-abliteration." + ), + ) + n_trials: int = Field( default=200, description="Number of abliteration trials to run during optimization.", diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index f2a8a258..abc080b2 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -64,15 +64,18 @@ def is_refusal(self, response: str) -> bool: return False - def count_refusals(self) -> int: + def count_refusals(self, prompts: list[Prompt] | None = None) -> int: + if prompts is None: + prompts = self.bad_prompts + refusal_count = 0 responses = self.model.get_responses_batched( - self.bad_prompts, + prompts, skip_special_tokens=True, ) - for prompt, response in zip(self.bad_prompts, responses): + for prompt, response in zip(prompts, responses): is_refusal = self.is_refusal(response) if is_refusal: refusal_count += 1 @@ -92,7 +95,7 @@ def count_refusals(self) -> int: return refusal_count - def get_score(self) -> tuple[tuple[float, float], float, int]: + def get_score(self) -> tuple[tuple[float, float], float, int, int]: print(" * Obtaining first-token probability distributions...") logprobs = self.model.get_logprobs_batched(self.good_prompts) kl_divergence = F.kl_div( @@ -107,6 +110,16 @@ def get_score(self) -> tuple[tuple[float, float], float, int]: refusals = self.count_refusals() print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}") + # Detect false refusals on benign prompts if enabled. + false_refusals = 0 + if self.settings.detect_false_refusals: + print(" * Checking for false refusals on good prompts...") + false_refusals = self.count_refusals(self.good_prompts) + color = "red" if false_refusals > 0 else "green" + print( + f" * False refusals: [{color}][bold]{false_refusals}[/]/{len(self.good_prompts)}[/]" + ) + kl_divergence_scale = self.settings.kl_divergence_scale kl_divergence_target = self.settings.kl_divergence_target @@ -117,9 +130,14 @@ def get_score(self) -> tuple[tuple[float, float], float, int]: else: kld_score = refusals_score * kl_divergence_target / kl_divergence_scale + # Penalize false refusals by adding to the KL component. + if false_refusals > 0: + false_refusal_rate = false_refusals / len(self.good_prompts) + kld_score += self.settings.false_refusal_weight * false_refusal_rate + score = ( kld_score, refusals_score, ) - return score, kl_divergence, refusals + return score, kl_divergence, refusals, false_refusals diff --git a/src/heretic/main.py b/src/heretic/main.py index c480888c..737e8341 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -555,7 +555,7 @@ def objective(trial: Trial) -> tuple[float, float]: print("* Abliterating...") model.abliterate(refusal_directions, direction_index, parameters) print("* Evaluating...") - score, kl_divergence, refusals = evaluator.get_score() + score, kl_divergence, refusals, false_refusals = evaluator.get_score() elapsed_time = time.perf_counter() - start_time remaining_time = (elapsed_time / (trial_index - start_index)) * ( @@ -571,6 +571,7 @@ def objective(trial: Trial) -> tuple[float, float]: trial.set_user_attr("kl_divergence", kl_divergence) trial.set_user_attr("refusals", refusals) + trial.set_user_attr("false_refusals", false_refusals) return score From 19cc2973d8a613e22699f54405ee713153873b60 Mon Sep 17 00:00:00 2001 From: KewkLW Date: Tue, 3 Mar 2026 22:46:19 -0800 Subject: [PATCH 2/2] fix: add new settings to config.default.toml Add detect_false_refusals and false_refusal_weight entries to match the repository convention of mirroring all config.py settings. --- config.default.toml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/config.default.toml b/config.default.toml index abfa0fc7..c0907a94 100644 --- a/config.default.toml +++ b/config.default.toml @@ -62,6 +62,16 @@ kl_divergence_scale = 1.0 # This helps prevent the sampler from extensively exploring parameter combinations that "do nothing". kl_divergence_target = 0.01 +# Whether to detect false refusals on benign (good) prompts during evaluation. +# When enabled, the model is checked for refusing harmless prompts that it shouldn't refuse. +# False refusals are penalized in the KL divergence component of the optimization score. +detect_false_refusals = false + +# Weight for false refusal penalty when detect_false_refusals is enabled. +# The penalty is: weight * (false_refusals / total_good_prompts). +# Higher values more aggressively penalize over-abliteration. +false_refusal_weight = 0.5 + # Whether to adjust the refusal directions so that only the component that is # orthogonal to the good direction is subtracted during abliteration. orthogonalize_direction = false