diff --git a/src/heretic/main.py b/src/heretic/main.py index e25dd813..19e66c60 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -43,7 +43,6 @@ def _is_help_invocation() -> bool: import torch import torch.nn.functional as F import transformers -from huggingface_hub import ModelCard, ModelCardData from lm_eval.models.huggingface import HFLM from optuna import Trial, TrialPruned from optuna.exceptions import ExperimentalWarning @@ -61,10 +60,10 @@ def _is_help_invocation() -> bool: from .config import QuantizationMethod from .evaluator import Evaluator from .model import AbliterationParameters, Model, get_model_class +from .model_card_utils import get_model_card from .system import empty_cache, get_accelerator_info from .utils import ( format_duration, - get_readme_intro, get_trial_parameters, is_hf_path, load_prompts, @@ -127,31 +126,26 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None: ) print() - strategy = prompt_select( - "How do you want to proceed?", - choices=[ - Choice( - title="Merge LoRA into full model" - + ( - "" - if settings.quantization == QuantizationMethod.NONE - else " (requires sufficient RAM)" - ), - value="merge", - ), - Choice( - title="Cancel", - value="cancel", + strategy = prompt_select( + "How do you want to proceed?", + choices=[ + Choice( + title="Merge LoRA into full model" + + ( + "" + if settings.quantization == QuantizationMethod.NONE + else " (requires sufficient RAM)" ), - ], - ) - - if strategy == "cancel": - return None + value="merge", + ), + Choice( + title="Save LoRA adapter only (can be merged later)", + value="adapter", + ), + ], + ) - return strategy - else: - return "merge" + return strategy def run(): @@ -784,6 +778,7 @@ def count_completed_trials() -> int: save_directory, max_shard_size=settings.max_shard_size, ) + card = get_model_card(settings, trial, "", True) else: print("Saving merged model...") merged_model = model.get_merged_model() @@ -794,6 +789,10 @@ def 
count_completed_trials() -> int: del merged_model empty_cache() model.tokenizer.save_pretrained(save_directory) + card = get_model_card(settings, trial, "", False) + + if card is not None: + card.save(f"{save_directory}/README.md") print(f"Model saved to [bold]{save_directory}[/].") @@ -887,6 +886,12 @@ def count_completed_trials() -> int: max_shard_size=settings.max_shard_size, token=token, ) + card = get_model_card( + settings, + trial, + reproducibility_information, + True, + ) else: print("Uploading merged model...") merged_model = model.get_merged_model() @@ -904,37 +909,14 @@ def count_completed_trials() -> int: token=token, ) - if is_hf_path(settings.model): - card = ModelCard.load(settings.model) - else: - card_path = ( - Path(settings.model) - / huggingface_hub.constants.REPOCARD_NAME + card = get_model_card( + settings, + trial, + reproducibility_information, + False, ) - if card_path.exists(): - card = ModelCard.load(card_path) - else: - card = None if card is not None: - if card.data is None: - card.data = ModelCardData() - if card.data.tags is None: - card.data.tags = [] - card.data.tags.append("heretic") - card.data.tags.append("uncensored") - card.data.tags.append("decensored") - card.data.tags.append("abliterated") - if reproducibility_information != "none": - card.data.tags.append("reproducible") - card.text = ( - get_readme_intro( - settings, - trial, - reproducibility_information != "none", - ) - + card.text - ) card.push_to_hub(repo_id, token=token) if reproducibility_information != "none": diff --git a/src/heretic/model_card_utils.py b/src/heretic/model_card_utils.py new file mode 100644 index 00000000..237bc50d --- /dev/null +++ b/src/heretic/model_card_utils.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: AGPL-3.0-or-later +# Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors + +from pathlib import Path + +import huggingface_hub +from huggingface_hub import ModelCard, ModelCardData +from optuna import Trial + +from .config 
def get_readme_intro(
    settings: Settings,
    trial: Trial,
    contains_reproducibility_information: bool,
    is_lora: bool,
) -> str:
    """Build the Markdown introduction that is prepended to the model card.

    The text consists of a heading naming the source model and the Heretic
    version, an optional tip pointing at the ``reproduce`` directory, a table
    of the trial's abliteration parameters and a table comparing KL
    divergence and refusal counts against the original model.

    Args:
        settings: Run settings; only ``settings.model`` is read here.
        trial: Trial whose parameters are listed. Its ``user_attrs`` must
            contain ``kl_divergence``, ``refusals``, ``base_refusals`` and
            ``n_bad_prompts``.
        contains_reproducibility_information: Whether to include the tip
            that links to ``reproduce/README.md``.
        is_lora: Whether the artifact is described as an "adapter" instead
            of a "version" of the source model.

    Returns:
        The Markdown text, ending with a horizontal rule and a blank line.

    Raises:
        KeyError: If one of the required ``user_attrs`` entries is missing.
    """
    if is_hf_path(settings.model):
        model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
    else:
        # Hide the path, which may contain private information.
        model_link = "a model"

    version_info = get_heretic_version_info()

    if contains_reproducibility_information:
        reproducibility_instructions = """
> [!TIP]
> **This model is reproducible!**
>
> See the [README](reproduce/README.md) in the `reproduce` directory for more information.
"""
    else:
        reproducibility_instructions = ""

    artifact = "adapter" if is_lora else "version"
    attrs = trial.user_attrs
    parameter_rows = "\n".join(
        f"| **{name}** | {value} |"
        for name, value in get_trial_parameters(trial).items()
    )

    return f"""# This is a decensored {artifact} of {model_link}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
{reproducibility_instructions}
## Abliteration parameters

| Parameter | Value |
| :-------- | :---: |
{parameter_rows}

## Performance

| Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: |
| **KL divergence** | {attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
| **Refusals** | {attrs["refusals"]}/{attrs["n_bad_prompts"]} | {attrs["base_refusals"]}/{attrs["n_bad_prompts"]} |

-----

"""


def get_model_card(
    settings: Settings,
    trial: Trial,
    reproducibility_information: str,
    is_lora: bool,
) -> ModelCard | None:
    """Load the source model's card and adapt it for the decensored model.

    The card is loaded via ``ModelCard.load(settings.model)`` when
    ``is_hf_path(settings.model)`` is true; otherwise it is read from the
    card file inside the ``settings.model`` directory, if that file exists.
    The loaded card gets the Heretic tags, base-model metadata (Hub models
    only) and the README introduction prepended to its text.

    Args:
        settings: Run settings (``model``, ``orthogonalize_direction`` and
            ``row_normalization`` are read).
        trial: Trial whose parameters and metrics go into the introduction.
        reproducibility_information: ``"none"`` or an empty string means
            that no reproducibility information accompanies the model; any
            other value marks the card as reproducible.
        is_lora: Whether the saved artifact is a LoRA adapter rather than a
            merged model.

    Returns:
        The updated card, or ``None`` if the model is a local path without
        a card file.
    """
    # Callers that only save the model locally pass an empty string. Treat
    # it like "none" so the card does not advertise a `reproduce` directory
    # that was never requested.
    is_reproducible = reproducibility_information not in ("", "none")

    if is_hf_path(settings.model):
        card = ModelCard.load(settings.model)
    else:
        card_path = Path(settings.model) / huggingface_hub.constants.REPOCARD_NAME
        card = ModelCard.load(card_path) if card_path.exists() else None

    if card is None:
        return None

    if card.data is None:
        card.data = ModelCardData()
    if card.data.tags is None:
        card.data.tags = []

    tags = ["heretic", "uncensored", "decensored", "abliterated"]
    if (
        settings.orthogonalize_direction
        and settings.row_normalization == RowNormalization.FULL
    ):
        tags.append("mpoa")
    if is_reproducible:
        tags.append("reproducible")

    # The source card may already carry some of these tags (for example when
    # the source model was itself produced by Heretic); don't duplicate them.
    for tag in tags:
        if tag not in card.data.tags:
            card.data.tags.append(tag)

    if is_hf_path(settings.model):
        card.data.base_model = settings.model
        # The Hub's `base_model_relation` metadata expects "finetune"
        # (not "finetuned") for fine-tuned derivatives.
        card.data.base_model_relation = "adapter" if is_lora else "finetune"

    card.text = (
        get_readme_intro(
            settings,
            trial,
            is_reproducible,
            is_lora,
        )
        + card.text
    )

    return card
-""" - else: - reproducibility_instructions = "" - - return f"""# This is a decensored version of { - model_link - }, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version} -{reproducibility_instructions} -## Abliteration parameters - -| Parameter | Value | -| :-------- | :---: | -{ - chr(10).join( - [ - f"| **{name}** | {value} |" - for name, value in get_trial_parameters(trial).items() - ] - ) - } - -## Performance - -| Metric | This model | Original model ({model_link}) | -| :----- | :--------: | :---------------------------: | -| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* | -| **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | { - trial.user_attrs["base_refusals"] - }/{trial.user_attrs["n_bad_prompts"]} | - ------ - -""" - - def generate_config_toml(settings: Settings) -> str: """Serializes the full Settings object to TOML."""