diff --git a/config.default.toml b/config.default.toml index abfa0fc7..c0907a94 100644 --- a/config.default.toml +++ b/config.default.toml @@ -62,6 +62,16 @@ kl_divergence_scale = 1.0 # This helps prevent the sampler from extensively exploring parameter combinations that "do nothing". kl_divergence_target = 0.01 +# Whether to detect false refusals on benign (good) prompts during evaluation. +# When enabled, the model is checked for refusing harmless prompts that it shouldn't refuse. +# False refusals are penalized in the KL divergence component of the optimization score. +detect_false_refusals = false + +# Weight for false refusal penalty when detect_false_refusals is enabled. +# The penalty is: weight * (false_refusals / total_good_prompts). +# Higher values more aggressively penalize over-abliteration. +false_refusal_weight = 0.5 + # Whether to adjust the refusal directions so that only the component that is # orthogonal to the good direction is subtracted during abliteration. orthogonalize_direction = false diff --git a/src/heretic/config.py b/src/heretic/config.py index 8ed3f80c..5137c3bb 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -215,6 +215,24 @@ class Settings(BaseSettings): ), ) + detect_false_refusals: bool = Field( + default=False, + description=( + "Whether to detect false refusals on benign (good) prompts during evaluation. " + "When enabled, the model is checked for refusing harmless prompts that it shouldn't refuse. " + "False refusals are penalized in the KL divergence component of the optimization score." + ), + ) + + false_refusal_weight: float = Field( + default=0.5, + description=( + "Weight for false refusal penalty when detect_false_refusals is enabled. " + "The penalty is: weight * (false_refusals / total_good_prompts). " + "Higher values more aggressively penalize over-abliteration." + ), + ) + n_trials: int = Field( default=200, description="Number of abliteration trials to run during optimization.", diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index f2a8a258..abc080b2 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -64,15 +64,18 @@ def is_refusal(self, response: str) -> bool: return False - def count_refusals(self) -> int: + def count_refusals(self, prompts: list[Prompt] | None = None) -> int: + if prompts is None: + prompts = self.bad_prompts + refusal_count = 0 responses = self.model.get_responses_batched( - self.bad_prompts, + prompts, skip_special_tokens=True, ) - for prompt, response in zip(self.bad_prompts, responses): + for prompt, response in zip(prompts, responses): is_refusal = self.is_refusal(response) if is_refusal: refusal_count += 1 @@ -92,7 +95,7 @@ def count_refusals(self) -> int: return refusal_count - def get_score(self) -> tuple[tuple[float, float], float, int]: + def get_score(self) -> tuple[tuple[float, float], float, int, int]: print(" * Obtaining first-token probability distributions...") logprobs = self.model.get_logprobs_batched(self.good_prompts) kl_divergence = F.kl_div( @@ -107,6 +110,16 @@ def get_score(self) -> tuple[tuple[float, float], float, int]: refusals = self.count_refusals() print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}") + # Detect false refusals on benign prompts if enabled. + false_refusals = 0 + if self.settings.detect_false_refusals: + print(" * Checking for false refusals on good prompts...") + false_refusals = self.count_refusals(self.good_prompts) + color = "red" if false_refusals > 0 else "green" + print( + f" * False refusals: [{color}][bold]{false_refusals}[/]/{len(self.good_prompts)}[/]" + ) + kl_divergence_scale = self.settings.kl_divergence_scale kl_divergence_target = self.settings.kl_divergence_target @@ -117,9 +130,14 @@ def get_score(self) -> tuple[tuple[float, float], float, int]: else: kld_score = refusals_score * kl_divergence_target / kl_divergence_scale + # Penalize false refusals by adding to the KL component. + if false_refusals > 0: + false_refusal_rate = false_refusals / len(self.good_prompts) + kld_score += self.settings.false_refusal_weight * false_refusal_rate + score = ( kld_score, refusals_score, ) - return score, kl_divergence, refusals + return score, kl_divergence, refusals, false_refusals diff --git a/src/heretic/main.py b/src/heretic/main.py index c480888c..737e8341 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -555,7 +555,7 @@ def objective(trial: Trial) -> tuple[float, float]: print("* Abliterating...") model.abliterate(refusal_directions, direction_index, parameters) print("* Evaluating...") - score, kl_divergence, refusals = evaluator.get_score() + score, kl_divergence, refusals, false_refusals = evaluator.get_score() elapsed_time = time.perf_counter() - start_time remaining_time = (elapsed_time / (trial_index - start_index)) * ( @@ -571,6 +571,7 @@ def objective(trial: Trial) -> tuple[float, float]: trial.set_user_attr("kl_divergence", kl_divergence) trial.set_user_attr("refusals", refusals) + trial.set_user_attr("false_refusals", false_refusals) return score