Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/heretic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,22 @@ class Settings(BaseSettings):
),
)

target_components: list[str] = Field(
default=["attn.o_proj", "mlp.down_proj"],
description=(
"List of component names to target for abliteration. "
'Currently supported values are "attn.o_proj" and "mlp.down_proj".'
),
)

use_ara: bool = Field(
default=True,
description=(
"Whether to use Arbitrary-Rank Ablation (ARA), an abliteration method based on matrix optimization, "
"instead of traditional directional ablation."
),
)
Comment thread
p-e-w marked this conversation as resolved.

orthogonalize_direction: bool = Field(
default=False,
description=(
Expand Down
279 changes: 169 additions & 110 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,9 @@ def run():
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
)

# We don't need gradients as we only do inference.
torch.set_grad_enabled(False)
if not settings.use_ara:
# We don't need gradients as we only do inference.
torch.set_grad_enabled(False)

# While determining the optimal batch size, we will try many different batch sizes,
# resulting in many computation graphs being compiled. Raising the limit (default = 8)
Expand Down Expand Up @@ -411,40 +412,47 @@ def run():
evaluator.get_score()
return

print()
print("Calculating per-layer refusal directions...")
print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts)

good_means = good_residuals.mean(dim=0)
bad_means = bad_residuals.mean(dim=0)

refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)

if settings.orthogonalize_direction:
# Implements https://huggingface.co/blog/grimjim/projected-abliteration
# Adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
good_directions = F.normalize(good_means, p=2, dim=1)
projection_vector = torch.sum(refusal_directions * good_directions, dim=1)
refusal_directions = (
refusal_directions - projection_vector.unsqueeze(1) * good_directions
)
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
if settings.use_ara:
print()
print("Obtaining module I/O for good prompts...")
good_module_io = model.get_module_io_batched(good_prompts)
print("Obtaining module I/O for bad prompts...")
bad_module_io = model.get_module_io_batched(bad_prompts)
else:
print()
print("Calculating per-layer refusal directions...")
print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts)

good_means = good_residuals.mean(dim=0)
bad_means = bad_residuals.mean(dim=0)

refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)

if settings.orthogonalize_direction:
# Implements https://huggingface.co/blog/grimjim/projected-abliteration
# Adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
good_directions = F.normalize(good_means, p=2, dim=1)
projection_vector = torch.sum(refusal_directions * good_directions, dim=1)
refusal_directions = (
refusal_directions - projection_vector.unsqueeze(1) * good_directions
)
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)

analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)

if settings.print_residual_geometry:
analyzer.print_residual_geometry()
if settings.print_residual_geometry:
analyzer.print_residual_geometry()

if settings.plot_residuals:
analyzer.plot_residuals()
if settings.plot_residuals:
analyzer.plot_residuals()

# We don't need the residuals after computing refusal directions.
del good_residuals, bad_residuals, analyzer
empty_cache()
# We don't need the residuals after computing refusal directions.
del good_residuals, bad_residuals, analyzer
empty_cache()

trial_index = 0
start_index = 0
Expand All @@ -455,83 +463,114 @@ def objective(trial: Trial) -> tuple[float, float]:
trial_index += 1
trial.set_user_attr("index", trial_index)

direction_scope = trial.suggest_categorical(
"direction_scope",
[
"global",
"per layer",
],
)

last_layer_index = len(model.get_layers()) - 1

# Discrimination between "harmful" and "harmless" inputs is usually strongest
# in layers slightly past the midpoint of the layer stack. See the original
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
#
# Note that we always sample this parameter even though we only need it for
# the "global" direction scope. The reason is that multivariate TPE doesn't
# work with conditional or variable-range parameters.
direction_index = trial.suggest_float(
"direction_index",
0.4 * last_layer_index,
0.9 * last_layer_index,
)

if direction_scope == "per layer":
direction_index = None

parameters = {}

for component in model.get_abliterable_components():
# The parameter ranges are based on experiments with various models
# and much wider ranges. They are not set in stone and might have to be
# adjusted for future models.
max_weight = trial.suggest_float(
f"{component}.max_weight",
0.8,
1.5,
if settings.use_ara:
start_layer_index = trial.suggest_int(
"start_layer_index",
0,
len(model.get_layers()) // 2,
)
max_weight_position = trial.suggest_float(
f"{component}.max_weight_position",
0.6 * last_layer_index,
1.0 * last_layer_index,
end_layer_index = trial.suggest_int(
"end_layer_index",
len(model.get_layers()) // 2,
len(model.get_layers()),
)
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
# again because multivariate TPE doesn't support variable-range parameters.
# The value is transformed into the actual min_weight value below.
min_weight = trial.suggest_float(
f"{component}.min_weight",
0.0,
optimization_balance = trial.suggest_float(
"optimization_balance",
-1.0,
1.0,
)
Comment thread
p-e-w marked this conversation as resolved.
min_weight_distance = trial.suggest_float(
f"{component}.min_weight_distance",
1.0,
0.6 * last_layer_index,
else:
direction_scope = trial.suggest_categorical(
"direction_scope",
[
"global",
"per layer",
],
)

parameters[component] = AbliterationParameters(
max_weight=max_weight,
max_weight_position=max_weight_position,
min_weight=(min_weight * max_weight),
min_weight_distance=min_weight_distance,
last_layer_index = len(model.get_layers()) - 1

# Discrimination between "harmful" and "harmless" inputs is usually strongest
# in layers slightly past the midpoint of the layer stack. See the original
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
#
# Note that we always sample this parameter even though we only need it for
# the "global" direction scope. The reason is that multivariate TPE doesn't
# work with conditional or variable-range parameters.
direction_index = trial.suggest_float(
"direction_index",
0.4 * last_layer_index,
0.9 * last_layer_index,
)

trial.set_user_attr("direction_index", direction_index)
trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()})
if direction_scope == "per layer":
direction_index = None

parameters = {}

for component in model.get_abliterable_components():
# The parameter ranges are based on experiments with various models
# and much wider ranges. They are not set in stone and might have to be
# adjusted for future models.
max_weight = trial.suggest_float(
f"{component}.max_weight",
0.8,
1.5,
)
max_weight_position = trial.suggest_float(
f"{component}.max_weight_position",
0.6 * last_layer_index,
1.0 * last_layer_index,
)
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
# again because multivariate TPE doesn't support variable-range parameters.
# The value is transformed into the actual min_weight value below.
min_weight = trial.suggest_float(
f"{component}.min_weight",
0.0,
1.0,
)
min_weight_distance = trial.suggest_float(
f"{component}.min_weight_distance",
1.0,
0.6 * last_layer_index,
)

parameters[component] = AbliterationParameters(
max_weight=max_weight,
max_weight_position=max_weight_position,
min_weight=(min_weight * max_weight),
min_weight_distance=min_weight_distance,
)

trial.set_user_attr("direction_index", direction_index)
trial.set_user_attr(
"parameters", {k: asdict(v) for k, v in parameters.items()}
)

print()
print(
f"Running trial [bold]{trial_index}[/] of [bold]{settings.n_trials}[/]..."
)
print("* Parameters:")
for name, value in get_trial_parameters(trial).items():
for name, value in get_trial_parameters(settings, trial).items():
print(f" * {name} = [bold]{value}[/]")
print("* Resetting model...")
model.reset_model()
print("* Abliterating...")
model.abliterate(refusal_directions, direction_index, parameters)
if settings.use_ara:
print("* Reloading model...")
model.reset_model()
print("* Abliterating (Arbitrary-Rank Ablation)...")
model.ara_abliterate(
good_module_io,
bad_module_io,
start_layer_index,
end_layer_index,
optimization_balance,
)
else:
print("* Resetting model...")
model.reset_model()
print("* Abliterating...")
model.abliterate(refusal_directions, direction_index, parameters)
print("* Evaluating...")
score, kl_divergence, refusals = evaluator.get_score()

Expand Down Expand Up @@ -708,19 +747,31 @@ def count_completed_trials() -> int:
print()
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
print("* Parameters:")
for name, value in get_trial_parameters(trial).items():
for name, value in get_trial_parameters(settings, trial).items():
print(f" * {name} = [bold]{value}[/]")
print("* Resetting model...")
model.reset_model()
print("* Abliterating...")
model.abliterate(
refusal_directions,
trial.user_attrs["direction_index"],
{
k: AbliterationParameters(**v)
for k, v in trial.user_attrs["parameters"].items()
},
)
if settings.use_ara:
print("* Reloading model...")
model.reset_model()
print("* Abliterating (Arbitrary-Rank Ablation)...")
model.ara_abliterate(
good_module_io,
bad_module_io,
trial.params["start_layer_index"],
trial.params["end_layer_index"],
trial.params["optimization_balance"],
)
else:
print("* Resetting model...")
model.reset_model()
print("* Abliterating...")
model.abliterate(
refusal_directions,
trial.user_attrs["direction_index"],
{
k: AbliterationParameters(**v)
for k, v in trial.user_attrs["parameters"].items()
},
)

while True:
print()
Expand Down Expand Up @@ -755,8 +806,12 @@ def count_completed_trials() -> int:
print("Saving LoRA adapter...")
model.model.save_pretrained(save_directory)
else:
print("Saving merged model...")
merged_model = model.get_merged_model()
if settings.use_ara:
print("Saving model...")
merged_model = model.model
else:
print("Saving merged model...")
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
Expand Down Expand Up @@ -808,8 +863,12 @@ def count_completed_trials() -> int:
token=token,
)
else:
print("Uploading merged model...")
merged_model = model.get_merged_model()
if settings.use_ara:
print("Uploading model...")
merged_model = model.model
else:
print("Uploading merged model...")
merged_model = model.get_merged_model()
merged_model.push_to_hub(
repo_id,
private=private,
Expand Down
Loading
Loading