Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 35 additions & 53 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def _is_help_invocation() -> bool:
import torch
import torch.nn.functional as F
import transformers
from huggingface_hub import ModelCard, ModelCardData
from lm_eval.models.huggingface import HFLM
from optuna import Trial, TrialPruned
from optuna.exceptions import ExperimentalWarning
Expand All @@ -61,10 +60,10 @@ def _is_help_invocation() -> bool:
from .config import QuantizationMethod
from .evaluator import Evaluator
from .model import AbliterationParameters, Model, get_model_class
from .model_card_utils import get_model_card
from .system import empty_cache, get_accelerator_info
from .utils import (
format_duration,
get_readme_intro,
get_trial_parameters,
is_hf_path,
load_prompts,
Expand Down Expand Up @@ -127,31 +126,26 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
)
print()

strategy = prompt_select(
"How do you want to proceed?",
choices=[
Choice(
title="Merge LoRA into full model"
+ (
""
if settings.quantization == QuantizationMethod.NONE
else " (requires sufficient RAM)"
),
value="merge",
),
Choice(
title="Cancel",
value="cancel",
strategy = prompt_select(
"How do you want to proceed?",
choices=[
Choice(
title="Merge LoRA into full model"
+ (
""
if settings.quantization == QuantizationMethod.NONE
else " (requires sufficient RAM)"
),
],
)

if strategy == "cancel":
return None
value="merge",
),
Choice(
title="Save LoRA adapter only (can be merged later)",
value="adapter",
),
],
)

return strategy
else:
return "merge"
return strategy


def run():
Expand Down Expand Up @@ -784,6 +778,7 @@ def count_completed_trials() -> int:
save_directory,
max_shard_size=settings.max_shard_size,
)
card = get_model_card(settings, trial, "", True)
else:
print("Saving merged model...")
merged_model = model.get_merged_model()
Expand All @@ -794,6 +789,10 @@ def count_completed_trials() -> int:
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
card = get_model_card(settings, trial, "", False)

if card is not None:
card.save(f"{save_directory}/README.md")

print(f"Model saved to [bold]{save_directory}[/].")

Expand Down Expand Up @@ -887,6 +886,12 @@ def count_completed_trials() -> int:
max_shard_size=settings.max_shard_size,
token=token,
)
card = get_model_card(
settings,
trial,
reproducibility_information,
True,
)
else:
print("Uploading merged model...")
merged_model = model.get_merged_model()
Expand All @@ -904,37 +909,14 @@ def count_completed_trials() -> int:
token=token,
)

if is_hf_path(settings.model):
card = ModelCard.load(settings.model)
else:
card_path = (
Path(settings.model)
/ huggingface_hub.constants.REPOCARD_NAME
card = get_model_card(
settings,
trial,
reproducibility_information,
False,
)
if card_path.exists():
card = ModelCard.load(card_path)
else:
card = None

if card is not None:
if card.data is None:
card.data = ModelCardData()
if card.data.tags is None:
card.data.tags = []
card.data.tags.append("heretic")
card.data.tags.append("uncensored")
card.data.tags.append("decensored")
card.data.tags.append("abliterated")
if reproducibility_information != "none":
card.data.tags.append("reproducible")
card.text = (
get_readme_intro(
settings,
trial,
reproducibility_information != "none",
)
+ card.text
)
card.push_to_hub(repo_id, token=token)

if reproducibility_information != "none":
Expand Down
123 changes: 123 additions & 0 deletions src/heretic/model_card_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors

from pathlib import Path

import huggingface_hub
from huggingface_hub import ModelCard, ModelCardData
from optuna import Trial

from .config import RowNormalization, Settings
from .system import (
get_heretic_version_info,
)
from .utils import get_trial_parameters, is_hf_path


def get_readme_intro(
settings: Settings,
trial: Trial,
contains_reproducibility_information: bool,
is_lora: bool,
) -> str:
if is_hf_path(settings.model):
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
else:
# Hide the path, which may contain private information.
model_link = "a model"

version_info = get_heretic_version_info()

if contains_reproducibility_information:
reproducibility_instructions = """
> [!TIP]
> **This model is reproducible!**
>
> See the [README](reproduce/README.md) in the `reproduce` directory for more information.
"""
else:
reproducibility_instructions = ""

return f"""# This is a decensored {"adapter" if is_lora else "version"} of {
model_link
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
{reproducibility_instructions}
## Abliteration parameters

| Parameter | Value |
| :-------- | :---: |
{
chr(10).join(
[
f"| **{name}** | {value} |"
for name, value in get_trial_parameters(trial).items()
]
)
}

## Performance

| Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: |
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
| **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | {
trial.user_attrs["base_refusals"]
}/{trial.user_attrs["n_bad_prompts"]} |

-----

"""


def get_model_card(
settings: Settings,
trial: Trial,
reproducibility_information: str,
is_lora: bool,
) -> ModelCard | None:
Comment thread
anrp marked this conversation as resolved.
# If the model path exists locally and includes the
# card, use it directly. If the model path doesn't
# exist locally, it can be assumed to be a model
# hosted on the Hugging Face Hub, in which case
# we can retrieve the model card.
if is_hf_path(settings.model):
card = ModelCard.load(settings.model)
else:
card_path = Path(settings.model) / huggingface_hub.constants.REPOCARD_NAME
if card_path.exists():
card = ModelCard.load(card_path)
else:
card = None

if card is not None:
if card.data is None:
card.data = ModelCardData()
if card.data.tags is None:
card.data.tags = []
card.data.tags.append("heretic")
card.data.tags.append("uncensored")
card.data.tags.append("decensored")
card.data.tags.append("abliterated")
if (
settings.orthogonalize_direction
and settings.row_normalization == RowNormalization.FULL
):
card.data.tags.append("mpoa")
if reproducibility_information != "none":
card.data.tags.append("reproducible")

if is_hf_path(settings.model):
card.data.base_model = settings.model
card.data.base_model_relation = "adapter" if is_lora else "finetuned"

card.text = (
get_readme_intro(
settings,
trial,
reproducibility_information != "none",
is_lora,
)
+ card.text
)

return card
54 changes: 0 additions & 54 deletions src/heretic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,60 +272,6 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
return params


def get_readme_intro(
settings: Settings,
trial: Trial,
contains_reproducibility_information: bool,
) -> str:
if is_hf_path(settings.model):
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
else:
# Hide the path, which may contain private information.
model_link = "a model"

version_info = get_heretic_version_info()

if contains_reproducibility_information:
reproducibility_instructions = """
> [!TIP]
> **This model is reproducible!**
>
> See the [README](reproduce/README.md) in the `reproduce` directory for more information.
"""
else:
reproducibility_instructions = ""

return f"""# This is a decensored version of {
model_link
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
{reproducibility_instructions}
## Abliteration parameters

| Parameter | Value |
| :-------- | :---: |
{
chr(10).join(
[
f"| **{name}** | {value} |"
for name, value in get_trial_parameters(trial).items()
]
)
}

## Performance

| Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: |
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
| **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | {
trial.user_attrs["base_refusals"]
}/{trial.user_attrs["n_bad_prompts"]} |

-----

"""


def generate_config_toml(settings: Settings) -> str:
"""Serializes the full Settings object to TOML."""

Expand Down