diff --git a/src/heretic/config.py b/src/heretic/config.py index 8b70499b..77de3c59 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -188,6 +188,30 @@ class Settings(BaseSettings): ), ) + target_components: list[str] = Field( + default=["attn.o_proj", "mlp.down_proj"], + description=( + "List of component names to target for abliteration. " + 'Currently supported values are "attn.o_proj" and "mlp.down_proj".' + ), + ) + + use_ara: bool = Field( + default=True, + description=( + "Whether to use Arbitrary-Rank Ablation (ARA), an abliteration method based on matrix optimization, " + "instead of traditional directional ablation." + ), + ) + + use_piqa: bool = Field( + default=False, + description=( + "Whether to use the Physical Interaction: Question Answering (PIQA) benchmark " + "as the quality metric instead of the Kullback-Leibler divergence." + ), + ) + orthogonalize_direction: bool = Field( default=False, description=( @@ -197,7 +221,7 @@ class Settings(BaseSettings): ) row_normalization: RowNormalization = Field( - default=RowNormalization.NONE, + default=RowNormalization.FULL, description=( "How to apply row normalization of the weights. Options: " '"none" (no normalization), ' diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index eced014e..5133a347 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors +import lm_eval import torch.nn.functional as F +from lm_eval.models.huggingface import HFLM from torch import Tensor from .config import Settings @@ -21,15 +23,16 @@ def __init__(self, settings: Settings, model: Model): self.settings = settings self.model = model - print() - print( - f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..." - ) - self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts) - print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded") + if not settings.use_piqa: + print() + print( + f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..." + ) + self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts) + print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded") - print("* Obtaining first-token probability distributions...") - self.base_logprobs = model.get_logprobs_batched(self.good_prompts) + print("* Obtaining first-token probability distributions...") + self.base_logprobs = model.get_logprobs_batched(self.good_prompts) print() print( @@ -93,35 +96,57 @@ def count_refusals(self) -> int: return refusal_count def get_score(self) -> tuple[tuple[float, float], float, int]: - print(" * Obtaining first-token probability distributions...") - logprobs = self.model.get_logprobs_batched(self.good_prompts) - kl_divergence = F.kl_div( - logprobs, - self.base_logprobs, - reduction="batchmean", - log_target=True, - ).item() - print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]") + if self.settings.use_piqa: + print(" * Running PIQA benchmark...") + hflm = HFLM( + pretrained=self.model.model, # ty:ignore[invalid-argument-type] + tokenizer=self.model.tokenizer, # ty:ignore[invalid-argument-type] + batch_size="auto", + ) + results = lm_eval.simple_evaluate( + model=hflm, + tasks=["piqa"], + ) + piqa_acc_norm: float = results["results"]["piqa"]["acc_norm,none"] + print(f" * PIQA acc_norm: [bold]{piqa_acc_norm:.4f}[/]") + else: + print(" * Obtaining first-token probability distributions...") + logprobs = self.model.get_logprobs_batched(self.good_prompts) + kl_divergence = F.kl_div( + logprobs, + self.base_logprobs, + reduction="batchmean", + log_target=True, + ).item() + print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]") print(" * Counting model refusals...") refusals = self.count_refusals() print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}") - kl_divergence_scale = self.settings.kl_divergence_scale - kl_divergence_target = self.settings.kl_divergence_target - refusals_score = ( refusals / self.base_refusals if self.base_refusals > 0 else float(refusals) ) - if kl_divergence >= kl_divergence_target: - kld_score = kl_divergence / kl_divergence_scale + if self.settings.use_piqa: + score = ( + -piqa_acc_norm, + refusals_score, + ) + + return score, -piqa_acc_norm, refusals else: - kld_score = refusals_score * kl_divergence_target / kl_divergence_scale + kl_divergence_scale = self.settings.kl_divergence_scale + kl_divergence_target = self.settings.kl_divergence_target - score = ( - kld_score, - refusals_score, - ) + if kl_divergence >= kl_divergence_target: + kld_score = kl_divergence / kl_divergence_scale + else: + kld_score = refusals_score * kl_divergence_target / kl_divergence_scale + + score = ( + kld_score, + refusals_score, + ) - return score, kl_divergence, refusals + return score, kl_divergence, refusals diff --git a/src/heretic/main.py b/src/heretic/main.py index fcc7e3d4..0abd11d7 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -51,9 +51,9 @@ from rich.traceback import install from .analyzer import Analyzer -from .config import QuantizationMethod, Settings +from .config import QuantizationMethod, RowNormalization, Settings from .evaluator import Evaluator -from .model import AbliterationParameters, Model, get_model_class +from .model import AbliterationParameters, ARAParameters, Model, get_model_class from .utils import ( empty_cache, format_duration, @@ -227,8 +227,9 @@ def run(): "[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]" ) - # We don't need gradients as we only do inference. - torch.set_grad_enabled(False) + if not settings.use_ara: + # We don't need gradients as we only do inference. + torch.set_grad_enabled(False) # While determining the optimal batch size, we will try many different batch sizes, # resulting in many computation graphs being compiled. Raising the limit (default = 8) @@ -451,40 +452,47 @@ def run(): evaluator.get_score() return - print() - print("Calculating per-layer refusal directions...") - print("* Obtaining residuals for good prompts...") - good_residuals = model.get_residuals_batched(good_prompts) - print("* Obtaining residuals for bad prompts...") - bad_residuals = model.get_residuals_batched(bad_prompts) - - good_means = good_residuals.mean(dim=0) - bad_means = bad_residuals.mean(dim=0) - - refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1) - - if settings.orthogonalize_direction: - # Implements https://huggingface.co/blog/grimjim/projected-abliteration - # Adjust the refusal directions so that only the component that is - # orthogonal to the good direction is subtracted during abliteration. - good_directions = F.normalize(good_means, p=2, dim=1) - projection_vector = torch.sum(refusal_directions * good_directions, dim=1) - refusal_directions = ( - refusal_directions - projection_vector.unsqueeze(1) * good_directions - ) - refusal_directions = F.normalize(refusal_directions, p=2, dim=1) + if settings.use_ara: + print() + print("Obtaining module I/O for good prompts...") + good_module_io = model.get_module_io_batched(good_prompts) + print("Obtaining module I/O for bad prompts...") + bad_module_io = model.get_module_io_batched(bad_prompts) + else: + print() + print("Calculating per-layer refusal directions...") + print("* Obtaining residuals for good prompts...") + good_residuals = model.get_residuals_batched(good_prompts) + print("* Obtaining residuals for bad prompts...") + bad_residuals = model.get_residuals_batched(bad_prompts) + + good_means = good_residuals.mean(dim=0) + bad_means = bad_residuals.mean(dim=0) + + refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1) + + if settings.orthogonalize_direction: + # Implements https://huggingface.co/blog/grimjim/projected-abliteration + # Adjust the refusal directions so that only the component that is + # orthogonal to the good direction is subtracted during abliteration. + good_directions = F.normalize(good_means, p=2, dim=1) + projection_vector = torch.sum(refusal_directions * good_directions, dim=1) + refusal_directions = ( + refusal_directions - projection_vector.unsqueeze(1) * good_directions + ) + refusal_directions = F.normalize(refusal_directions, p=2, dim=1) - analyzer = Analyzer(settings, model, good_residuals, bad_residuals) + analyzer = Analyzer(settings, model, good_residuals, bad_residuals) - if settings.print_residual_geometry: - analyzer.print_residual_geometry() + if settings.print_residual_geometry: + analyzer.print_residual_geometry() - if settings.plot_residuals: - analyzer.plot_residuals() + if settings.plot_residuals: + analyzer.plot_residuals() - # We don't need the residuals after computing refusal directions. - del good_residuals, bad_residuals, analyzer - empty_cache() + # We don't need the residuals after computing refusal directions. + del good_residuals, bad_residuals, analyzer + empty_cache() trial_index = 0 start_index = 0 @@ -495,83 +503,135 @@ def objective(trial: Trial) -> tuple[float, float]: trial_index += 1 trial.set_user_attr("index", trial_index) - direction_scope = trial.suggest_categorical( - "direction_scope", - [ - "global", - "per layer", - ], - ) - - last_layer_index = len(model.get_layers()) - 1 - - # Discrimination between "harmful" and "harmless" inputs is usually strongest - # in layers slightly past the midpoint of the layer stack. See the original - # abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis. - # - # Note that we always sample this parameter even though we only need it for - # the "global" direction scope. The reason is that multivariate TPE doesn't - # work with conditional or variable-range parameters. - direction_index = trial.suggest_float( - "direction_index", - 0.4 * last_layer_index, - 0.9 * last_layer_index, - ) - - if direction_scope == "per layer": - direction_index = None - - parameters = {} - - for component in model.get_abliterable_components(): - # The parameter ranges are based on experiments with various models - # and much wider ranges. They are not set in stone and might have to be - # adjusted for future models. - max_weight = trial.suggest_float( - f"{component}.max_weight", - 0.8, - 1.5, + if settings.use_ara: + start_layer_index = trial.suggest_int( + "start_layer_index", + 0, + len(model.get_layers()) // 2, ) - max_weight_position = trial.suggest_float( - f"{component}.max_weight_position", - 0.6 * last_layer_index, - 1.0 * last_layer_index, + end_layer_index = trial.suggest_int( + "end_layer_index", + len(model.get_layers()) // 2, + len(model.get_layers()), ) - # For sampling purposes, min_weight is expressed as a fraction of max_weight, - # again because multivariate TPE doesn't support variable-range parameters. - # The value is transformed into the actual min_weight value below. - min_weight = trial.suggest_float( - f"{component}.min_weight", + preserve_good_behavior_weight = trial.suggest_float( + "preserve_good_behavior_weight", 0.0, 1.0, ) - min_weight_distance = trial.suggest_float( - f"{component}.min_weight_distance", + steer_bad_behavior_weight = trial.suggest_float( + "steer_bad_behavior_weight", + 0.0001, 1.0, - 0.6 * last_layer_index, + log=True, + ) + overcorrect_relative_weight = trial.suggest_float( + "overcorrect_relative_weight", + 0.0, + 1.3, + ) + neighbor_count = trial.suggest_int( + "neighbor_count", + 1, + 15, + ) + + ara_parameters = ARAParameters( + start_layer_index=start_layer_index, + end_layer_index=end_layer_index, + preserve_good_behavior_weight=preserve_good_behavior_weight, + steer_bad_behavior_weight=steer_bad_behavior_weight, + overcorrect_relative_weight=overcorrect_relative_weight, + neighbor_count=neighbor_count, ) - parameters[component] = AbliterationParameters( - max_weight=max_weight, - max_weight_position=max_weight_position, - min_weight=(min_weight * max_weight), - min_weight_distance=min_weight_distance, + trial.set_user_attr("ara_parameters", asdict(ara_parameters)) + else: + direction_scope = trial.suggest_categorical( + "direction_scope", + [ + "global", + "per layer", + ], + ) + + last_layer_index = len(model.get_layers()) - 1 + + # Discrimination between "harmful" and "harmless" inputs is usually strongest + # in layers slightly past the midpoint of the layer stack. See the original + # abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis. + # + # Note that we always sample this parameter even though we only need it for + # the "global" direction scope. The reason is that multivariate TPE doesn't + # work with conditional or variable-range parameters. + direction_index = trial.suggest_float( + "direction_index", + 0.4 * last_layer_index, + 0.9 * last_layer_index, ) - trial.set_user_attr("direction_index", direction_index) - trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()}) + if direction_scope == "per layer": + direction_index = None + + parameters = {} + + for component in model.get_abliterable_components(): + # The parameter ranges are based on experiments with various models + # and much wider ranges. They are not set in stone and might have to be + # adjusted for future models. + max_weight = trial.suggest_float( + f"{component}.max_weight", + 0.8, + 1.5, + ) + max_weight_position = trial.suggest_float( + f"{component}.max_weight_position", + 0.6 * last_layer_index, + 1.0 * last_layer_index, + ) + # For sampling purposes, min_weight is expressed as a fraction of max_weight, + # again because multivariate TPE doesn't support variable-range parameters. + # The value is transformed into the actual min_weight value below. + min_weight = trial.suggest_float( + f"{component}.min_weight", + 0.0, + 1.0, + ) + min_weight_distance = trial.suggest_float( + f"{component}.min_weight_distance", + 1.0, + 0.6 * last_layer_index, + ) + + parameters[component] = AbliterationParameters( + max_weight=max_weight, + max_weight_position=max_weight_position, + min_weight=(min_weight * max_weight), + min_weight_distance=min_weight_distance, + ) + + trial.set_user_attr("direction_index", direction_index) + trial.set_user_attr( + "parameters", {k: asdict(v) for k, v in parameters.items()} + ) print() print( f"Running trial [bold]{trial_index}[/] of [bold]{settings.n_trials}[/]..." ) print("* Parameters:") - for name, value in get_trial_parameters(trial).items(): + for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") - print("* Resetting model...") - model.reset_model() - print("* Abliterating...") - model.abliterate(refusal_directions, direction_index, parameters) + if settings.use_ara: + print("* Reloading model...") + model.reset_model() + print("* Abliterating (Arbitrary-Rank Ablation)...") + model.ara_abliterate(good_module_io, bad_module_io, ara_parameters) + else: + print("* Resetting model...") + model.reset_model() + print("* Abliterating...") + model.abliterate(refusal_directions, direction_index, parameters) print("* Evaluating...") score, kl_divergence, refusals = evaluator.get_score() @@ -669,7 +729,7 @@ def count_completed_trials() -> int: title=( f"[Trial {trial.user_attrs['index']:>3}] " f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, " - f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}" + f"{'PIQA acc_norm' if settings.use_piqa else 'KL divergence'}: {(-1 if settings.use_piqa else 1) * trial.user_attrs['kl_divergence']:.4f}" ), value=trial, ) @@ -748,19 +808,29 @@ def count_completed_trials() -> int: print() print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...") print("* Parameters:") - for name, value in get_trial_parameters(trial).items(): + for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") - print("* Resetting model...") - model.reset_model() - print("* Abliterating...") - model.abliterate( - refusal_directions, - trial.user_attrs["direction_index"], - { - k: AbliterationParameters(**v) - for k, v in trial.user_attrs["parameters"].items() - }, - ) + if settings.use_ara: + print("* Reloading model...") + model.reset_model() + print("* Abliterating (Arbitrary-Rank Ablation)...") + model.ara_abliterate( + good_module_io, + bad_module_io, + ARAParameters(**trial.user_attrs["ara_parameters"]), + ) + else: + print("* Resetting model...") + model.reset_model() + print("* Abliterating...") + model.abliterate( + refusal_directions, + trial.user_attrs["direction_index"], + { + k: AbliterationParameters(**v) + for k, v in trial.user_attrs["parameters"].items() + }, + ) while True: print() @@ -796,8 +866,12 @@ def count_completed_trials() -> int: print("Saving LoRA adapter...") model.model.save_pretrained(save_directory) else: - print("Saving merged model...") - merged_model = model.get_merged_model() + if settings.use_ara: + print("Saving model...") + merged_model = model.model + else: + print("Saving merged model...") + merged_model = model.get_merged_model() merged_model.save_pretrained(save_directory) del merged_model empty_cache() @@ -851,8 +925,12 @@ def count_completed_trials() -> int: token=token, ) else: - print("Uploading merged model...") - merged_model = model.get_merged_model() + if settings.use_ara: + print("Uploading model...") + merged_model = model.model + else: + print("Uploading merged model...") + merged_model = model.get_merged_model() merged_model.push_to_hub( repo_id, private=private, @@ -891,6 +969,14 @@ def count_completed_trials() -> int: card.data.tags.append("uncensored") card.data.tags.append("decensored") card.data.tags.append("abliterated") + if settings.use_ara: + card.data.tags.append("ara") + elif ( + settings.orthogonalize_direction + and settings.row_normalization + == RowNormalization.FULL + ): + card.data.tags.append("mpoa") card.text = ( get_readme_intro( settings, @@ -966,6 +1052,7 @@ def count_completed_trials() -> int: hflm = HFLM( pretrained=model.model, # ty:ignore[invalid-argument-type] tokenizer=model.tokenizer, # ty:ignore[invalid-argument-type] + batch_size="auto", ) table = Table() @@ -989,7 +1076,6 @@ def get_results() -> dict[str, Any]: results = lm_eval.simple_evaluate( model=hflm, tasks=[benchmark.task], - batch_size="auto", ) return results["results"][benchmark.task] diff --git a/src/heretic/model.py b/src/heretic/model.py index 55afa267..108e9dfd 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -4,7 +4,7 @@ import math from contextlib import suppress from dataclasses import dataclass -from typing import Any, Type, cast +from typing import Any, Callable, Type, TypeAlias, cast import bitsandbytes as bnb import torch @@ -14,6 +14,8 @@ from peft.tuners.lora.layer import Linear from torch import FloatTensor, LongTensor, Tensor from torch.nn import Module, ModuleList +from torch.optim import LBFGS +from torch.utils.hooks import RemovableHandle from transformers import ( AutoModelForCausalLM, AutoModelForImageTextToText, @@ -30,7 +32,7 @@ ) from .config import QuantizationMethod, RowNormalization, Settings -from .utils import Prompt, batchify, empty_cache, print +from .utils import Prompt, batchify, empty_cache, mean_distances_to_knn, print def get_model_class( @@ -52,6 +54,23 @@ class AbliterationParameters: min_weight_distance: float +@dataclass +class ARAParameters: + start_layer_index: int + end_layer_index: int + preserve_good_behavior_weight: float + steer_bad_behavior_weight: float + overcorrect_relative_weight: float + neighbor_count: int + + +# The list contains one element per layer. +# Each element maps from the component name to a (possibly sparse) mapping +# from the module index to an (input, output) tuple containing the I/O +# tensors of shape (prompt, component). +ModuleIO: TypeAlias = list[dict[str, dict[int, tuple[Tensor, Tensor]]]] + + class Model: model: PreTrainedModel | PeftModel tokenizer: PreTrainedTokenizerBase @@ -142,7 +161,8 @@ def __init__(self, settings: Settings): if self.model is None: raise Exception("Failed to load model with all configured dtypes.") - self._apply_lora() + if not settings.use_ara: + self._apply_lora() # LoRA B matrices are initialized to zero by default in PEFT, # so we don't need to do anything manually. @@ -288,7 +308,11 @@ def reset_model(self): performs full model reload with quantization config. """ current_model = getattr(self.model.config, "name_or_path", None) - if current_model == self.settings.model and not self.needs_reload: + if ( + current_model == self.settings.model + and not self.needs_reload + and not self.settings.use_ara + ): # Reset LoRA adapters to zero (identity transformation) for name, module in self.model.named_modules(): if "lora_B" in name and hasattr(module, "weight"): @@ -317,7 +341,8 @@ def reset_model(self): **extra_kwargs, ) - self._apply_lora() + if not self.settings.use_ara: + self._apply_lora() self.needs_reload = False @@ -341,6 +366,9 @@ def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]: modules = {} def try_add(component: str, module: Any): + if component not in self.settings.target_components: + return + # Only add if it's a proper nn.Module (PEFT can wrap these with LoRA) if isinstance(module, Module): if component not in modules: @@ -541,6 +569,105 @@ def abliterate( weight_A.data = lora_A.to(weight_A.dtype) weight_B.data = lora_B.to(weight_B.dtype) + def ara_abliterate( + self, + good_module_io: ModuleIO, + bad_module_io: ModuleIO, + parameters: ARAParameters, + ): + for layer_index in range( + parameters.start_layer_index, + parameters.end_layer_index, + ): + for component, modules in self.get_layer_modules(layer_index).items(): + for module_index, module in enumerate(modules): + # See above for a (partial) justification of this cast. + module = cast(Linear, module) + matrix = module.weight + + row_norms = LA.vector_norm(matrix, dim=1, keepdim=True).detach() + + # Helper function for reparameterization (row-norm preservation constraint). + def get_matrix() -> Tensor: + if self.settings.row_normalization == RowNormalization.FULL: + # See https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration + return row_norms * F.normalize(matrix, p=2, dim=1) + else: + return matrix + + good_input, good_output = good_module_io[layer_index][component][ + module_index + ] + bad_input, bad_output = bad_module_io[layer_index][component][ + module_index + ] + + good_input = good_input.to(matrix.device) + good_output = good_output.to(matrix.device) + bad_input = bad_input.to(matrix.device) + bad_output = bad_output.to(matrix.device) + + def objective(matrix: Tensor) -> Tensor: + new_good_output = good_input @ matrix.T + new_bad_output = bad_input @ matrix.T + + # The outputs for "good" prompts should change as little as possible. + preserve_good_behavior = ( + (new_good_output - good_output) ** 2 + ).mean() + + steer_bad_behavior = ( + # Pull the outputs for "bad" prompts towards + # the original outputs for "good" prompts. + mean_distances_to_knn( + new_bad_output, + good_output, + parameters.neighbor_count, + ).mean() + # Push the outputs for "bad" prompts away from + # the original outputs for "bad" prompts. + # In combination with the above, this overcorrects + # away from the original residuals, which results + # in stronger steering that can overcome more complex + # refusal mechanisms. + + parameters.overcorrect_relative_weight + * -mean_distances_to_knn( + new_bad_output, + bad_output, + parameters.neighbor_count, + ).mean() + ) + + return ( + parameters.preserve_good_behavior_weight + * preserve_good_behavior + + parameters.steer_bad_behavior_weight * steer_bad_behavior + ) + + optimizer = LBFGS( + [matrix], + lr=1.0, + max_iter=20, # Number of internal iterations per step, *not* the number of steps. + history_size=10, + line_search_fn="strong_wolfe", + ) + + def closure() -> Tensor: + optimizer.zero_grad() + loss = objective(get_matrix()) + loss.backward() + return loss + + # Convergence usually happens within 2-3 steps, so this is more than enough. + for step in range(5): + loss = optimizer.step(closure) + # print( + # f"\\[{layer_index}/{component}/{module_index}] Step: {step}, Loss: {loss.item():.6f}" + # ) + + with torch.no_grad(): + matrix.copy_(get_matrix()) + def generate( self, prompts: list[Prompt], @@ -675,6 +802,132 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor: return torch.cat(residuals, dim=0) + def get_module_io( + self, + prompts: list[Prompt], + ) -> ModuleIO: + # The list contains one element per layer. + # Each element maps from the component name to a (possibly sparse) mapping + # from the module index to an (input, output) tuple containing the I/O + # tensors of shape (prompt, component). + module_io: ModuleIO = [] + + def get_hook( + layer_index: int, + component: str, + module_index: int, + ) -> Callable[[Module, tuple[Tensor, ...], Tensor], None]: + def hook( + module: Module, + inputs: tuple[Tensor, ...], + outputs: Tensor, + ) -> None: + if len(module_io) == layer_index: + # First invocation of the hook for this layer. + module_io.append({}) + + # Layers are invoked in order during inference, + # so this should always hold. + assert len(module_io) == layer_index + 1 + + if component not in module_io[layer_index]: + module_io[layer_index][component] = {} + + # Each module should be invoked at most once per inference step. + assert module_index not in module_io[layer_index][component] + + # inputs[0] and outputs have shape (prompt, position, component), + # so this extracts the input/output at the end of each prompt. + # Move to CPU to decouple from device assignments, which can + # change between model reloads in multi-GPU configurations. + input = inputs[0][:, -1, :].detach().clone().cpu() + output = outputs[:, -1, :].detach().clone().cpu() + + # The modules associated with a component (e.g. expert MLPs) + # are not necessarily invoked in order, nor are all of them + # necessarily invoked in each inference step, so we cannot + # use a list here. + module_io[layer_index][component][module_index] = (input, output) + + return hook + + hook_handles: list[RemovableHandle] = [] + + for layer_index in range(len(self.get_layers())): + for component, modules in self.get_layer_modules(layer_index).items(): + for module_index, module in enumerate(modules): + hook_handles.append( + module.register_forward_hook( + get_hook(layer_index, component, module_index) + ) + ) + + self.generate(prompts, max_new_tokens=1) + + for hook_handle in hook_handles: + hook_handle.remove() + + return module_io + + def get_module_io_batched( + self, + prompts: list[Prompt], + ) -> ModuleIO: + # Aggregating batch results is more complicated for module I/O + # than for other get_*_batched methods, because the structure of the results + # might differ between batches, as whether individual modules activate + # can depend on the prompt (in particular for MoE models). + # In practice, inhomogeneous results should be very rare, but to be fully + # generic, this logic is required. + module_io_batches: list[ModuleIO] = [ + self.get_module_io(batch) + for batch in batchify(prompts, self.settings.batch_size) + ] + + module_io: ModuleIO = [] + + for layer_index in range(len(self.get_layers())): + module_io.append({}) + + for module_io_batch in module_io_batches: + for component, io_map in module_io_batch[layer_index].items(): + if component not in module_io[layer_index]: + module_io[layer_index][component] = {} + + for module_index in io_map: + if module_index not in module_io[layer_index][component]: + # This is a placeholder; the actual aggregation happens below. + # We need to iterate over the batches twice because we don't + # know in advance which components and module indices are present. + module_io[layer_index][component][module_index] = ( + torch.empty(0), + torch.empty(0), + ) + + for component, io_map in module_io[layer_index].items(): + for module_index in io_map: + inputs_outputs = [ + module_io_batch[layer_index][component][module_index] + for module_io_batch in module_io_batches + if component in module_io_batch[layer_index] + and module_index in module_io_batch[layer_index][component] + ] + input = torch.cat( + [input_output[0] for input_output in inputs_outputs], + dim=0, + ) + output = torch.cat( + [input_output[1] for input_output in inputs_outputs], + dim=0, + ) + + # The key already exists, and replacing existing values + # in a dictionary while iterating over the same dictionary + # is safe in Python. + module_io[layer_index][component][module_index] = (input, output) + + return module_io + # We work with logprobs rather than probabilities for numerical stability # when computing the KL divergence. def get_logprobs(self, prompts: list[Prompt]) -> Tensor: diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 288ca0ff..96896ad3 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -25,8 +25,9 @@ from psutil import Process from questionary import Choice, Style from rich.console import Console +from torch import Tensor -from .config import DatasetSpecification, Settings +from .config import DatasetSpecification, RowNormalization, Settings print = Console(highlight=False).print @@ -234,6 +235,14 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]: return [items[i : i + batch_size] for i in range(0, len(items), batch_size)] +# For each vector in the 2D-tensor `a`, computes the mean Euclidean distance +# to the `k` nearest neighbors of the vector among the vectors in the 2D-tensor `b`. +def mean_distances_to_knn(a: Tensor, b: Tensor, k: int) -> Tensor: + distances = torch.cdist(a, b) + nearest_distances, _ = distances.topk(k, dim=1, largest=False) + return nearest_distances.mean(1) + + def empty_cache(): # Collecting garbage is not an idempotent operation, and to avoid OOM errors, # gc.collect() has to be called both before and after emptying the backend cache. @@ -256,19 +265,46 @@ def empty_cache(): gc.collect() -def get_trial_parameters(trial: Trial) -> dict[str, str]: - params = {} +def get_trial_parameters(settings: Settings, trial: Trial) -> dict[str, str]: + if settings.use_ara: + parameters = trial.user_attrs["ara_parameters"] - direction_index = trial.user_attrs["direction_index"] - params["direction_index"] = ( - "per layer" if (direction_index is None) else f"{direction_index:.2f}" - ) + return { + name: (f"{value:.4f}" if isinstance(value, float) else f"{value}") + for name, value in parameters.items() + } + else: + params = {} - for component, parameters in trial.user_attrs["parameters"].items(): - for name, value in parameters.items(): - params[f"{component}.{name}"] = f"{value:.2f}" + direction_index = trial.user_attrs["direction_index"] + params["direction_index"] = ( + "per layer" if (direction_index is None) else f"{direction_index:.2f}" + ) + + for component, parameters in trial.user_attrs["parameters"].items(): + for name, value in parameters.items(): + params[f"{component}.{name}"] = f"{value:.2f}" - return params + return params + + +def get_method_description(settings: Settings) -> str: + if settings.use_ara: + return ( + " with the [Arbitrary-Rank Ablation (ARA)](https://github.com/p-e-w/heretic/pull/211) method" + + ( + " (with row-norm preservation)" + if settings.row_normalization == RowNormalization.FULL + else "" + ) + ) + elif ( + settings.orthogonalize_direction + and settings.row_normalization == RowNormalization.FULL + ): + return " with a variant of the [Magnitude-Preserving Orthogonal Ablation (MPOA)](https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration) method" + else: + return "" def get_readme_intro( @@ -285,7 +321,9 @@ def get_readme_intro( return f"""# This is a decensored version of { model_link - }, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")} + }, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")}{ + get_method_description(settings) + } ## Abliteration parameters @@ -295,7 +333,7 @@ def get_readme_intro( chr(10).join( [ f"| **{name}** | {value} |" - for name, value in get_trial_parameters(trial).items() + for name, value in get_trial_parameters(settings, trial).items() ] ) } @@ -304,7 +342,10 @@ def get_readme_intro( | Metric | This model | Original model ({model_link}) | | :----- | :--------: | :---------------------------: | -| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* | +| **{"PIQA acc_norm" if settings.use_piqa else "KL divergence"}** | { + (-1 if settings.use_piqa else 1) * trial.user_attrs["kl_divergence"]:.4f} | { + "*Unknown*" if settings.use_piqa else "0 *(by definition)*" + } | | **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{ len(bad_prompts) } |