Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/heretic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,30 @@ class Settings(BaseSettings):
),
)

target_components: list[str] = Field(
default=["attn.o_proj", "mlp.down_proj"],
description=(
"List of component names to target for abliteration. "
'Currently supported values are "attn.o_proj" and "mlp.down_proj".'
),
)

use_ara: bool = Field(
default=True,
description=(
"Whether to use Arbitrary-Rank Ablation (ARA), an abliteration method based on matrix optimization, "
"instead of traditional directional ablation."
),
)
Comment thread
p-e-w marked this conversation as resolved.

use_piqa: bool = Field(
default=False,
description=(
"Whether to use the Physical Interaction: Question Answering (PIQA) benchmark "
"as the quality metric instead of the Kullback-Leibler divergence."
),
)

orthogonalize_direction: bool = Field(
default=False,
description=(
Expand All @@ -197,7 +221,7 @@ class Settings(BaseSettings):
)

row_normalization: RowNormalization = Field(
default=RowNormalization.NONE,
default=RowNormalization.FULL,
description=(
"How to apply row normalization of the weights. Options: "
'"none" (no normalization), '
Expand Down
81 changes: 53 additions & 28 deletions src/heretic/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors

import lm_eval
import torch.nn.functional as F
from lm_eval.models.huggingface import HFLM
from torch import Tensor

from .config import Settings
Expand All @@ -21,15 +23,16 @@ def __init__(self, settings: Settings, model: Model):
self.settings = settings
self.model = model

print()
print(
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
)
self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
if not settings.use_piqa:
print()
print(
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
)
self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")

print("* Obtaining first-token probability distributions...")
self.base_logprobs = model.get_logprobs_batched(self.good_prompts)
print("* Obtaining first-token probability distributions...")
self.base_logprobs = model.get_logprobs_batched(self.good_prompts)

print()
print(
Expand Down Expand Up @@ -93,35 +96,57 @@ def count_refusals(self) -> int:
return refusal_count

def get_score(self) -> tuple[tuple[float, float], float, int]:
print(" * Obtaining first-token probability distributions...")
logprobs = self.model.get_logprobs_batched(self.good_prompts)
kl_divergence = F.kl_div(
logprobs,
self.base_logprobs,
reduction="batchmean",
log_target=True,
).item()
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")
if self.settings.use_piqa:
print(" * Running PIQA benchmark...")
hflm = HFLM(
pretrained=self.model.model, # ty:ignore[invalid-argument-type]
tokenizer=self.model.tokenizer, # ty:ignore[invalid-argument-type]
batch_size="auto",
)
results = lm_eval.simple_evaluate(
model=hflm,
tasks=["piqa"],
)
piqa_acc_norm: float = results["results"]["piqa"]["acc_norm,none"]
print(f" * PIQA acc_norm: [bold]{piqa_acc_norm:.4f}[/]")
else:
print(" * Obtaining first-token probability distributions...")
logprobs = self.model.get_logprobs_batched(self.good_prompts)
kl_divergence = F.kl_div(
logprobs,
self.base_logprobs,
reduction="batchmean",
log_target=True,
).item()
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")

print(" * Counting model refusals...")
refusals = self.count_refusals()
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")

kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target

refusals_score = (
refusals / self.base_refusals if self.base_refusals > 0 else float(refusals)
)

if kl_divergence >= kl_divergence_target:
kld_score = kl_divergence / kl_divergence_scale
if self.settings.use_piqa:
score = (
-piqa_acc_norm,
refusals_score,
)

return score, -piqa_acc_norm, refusals
else:
kld_score = refusals_score * kl_divergence_target / kl_divergence_scale
kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target

score = (
kld_score,
refusals_score,
)
if kl_divergence >= kl_divergence_target:
kld_score = kl_divergence / kl_divergence_scale
else:
kld_score = refusals_score * kl_divergence_target / kl_divergence_scale

score = (
kld_score,
refusals_score,
)

return score, kl_divergence, refusals
return score, kl_divergence, refusals
Loading
Loading