Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/heretic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ class Settings(BaseSettings):
exclude=True,
)

reproduce: str | None = Field(
Comment thread
p-e-w marked this conversation as resolved.
default=None,
description=(
"If this path or URL to a reproduce.json file is set, load reproduction information "
"from that file, and attempt to reproduce the abliterated model it originated from."
),
exclude=True,
)

dtypes: list[str] = Field(
default=[
# In practice, "auto" almost always means bfloat16.
Expand Down
13 changes: 11 additions & 2 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _is_help_invocation() -> bool:
from .config import QuantizationMethod
from .evaluator import Evaluator
from .model import AbliterationParameters, Model, get_model_class
from .reproduce import collect_reproducibles
from .reproduce import collect_reproducibles, load_reproduction_information
from .system import empty_cache, get_accelerator_info
from .utils import (
format_duration,
Expand Down Expand Up @@ -175,6 +175,7 @@ def run():
len(sys.argv) > 1
# Heretic is being invoked in standard (model processing) mode.
and "--collect-reproducibles" not in sys.argv
and "--reproduce" not in sys.argv
# No model has been explicitly provided.
and "--model" not in sys.argv
# The last argument is a parameter value rather than a flag (such as "--help").
Expand All @@ -185,7 +186,9 @@ def run():

# Work around the "model" argument being required
# when Heretic is invoked in a non-processing mode.
if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv:
if (
"--collect-reproducibles" in sys.argv or "--reproduce" in sys.argv
) and "--model" not in sys.argv:
sys.argv.extend(["--model", ""])

try:
Expand All @@ -208,6 +211,12 @@ def run():
collect_reproducibles(settings.collect_reproducibles)
return

if settings.reproduce is not None:
print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...")
reproduction_information = load_reproduction_information(settings.reproduce)
print(reproduction_information)
return

if settings.seed is None:
settings.seed = random.randint(0, 2**32 - 1)

Expand Down
19 changes: 19 additions & 0 deletions src/heretic/reproduce.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors

import json
import shutil
from pathlib import Path
from typing import Any
from urllib.request import urlopen

from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
Expand Down Expand Up @@ -81,3 +84,19 @@ def collect_reproducibles(path: str):
print(f"Found: [bold]{found}[/] files")
print(f"Downloaded: [bold]{downloaded}[/] files")
print(f"Already stored: [bold]{found - downloaded}[/] files")


def load_reproduction_information(
    path: str,
    *,
    timeout: float = 10.0,
) -> dict[str, Any]:
    """Load reproduction information from a reproduce.json file.

    Args:
        path: Either a local file system path, or an ``http://`` / ``https://``
            URL (scheme matched case-insensitively). Web-page style URLs are
            rewritten to their raw-download form before fetching.
        timeout: Maximum number of seconds to wait on the network when
            ``path`` is a URL. Ignored for local paths.

    Returns:
        The parsed JSON content of the file.

    Raises:
        OSError: If the local file cannot be read or the URL cannot be
            fetched (``urllib.error.URLError`` is a subclass of ``OSError``).
        json.JSONDecodeError: If the content is not valid JSON.
    """
    if path.lower().startswith(("http://", "https://")):
        # The path is a URL on the web.

        # Obtain raw download URL.
        path = path.replace("/blob/", "/raw/")  # Hugging Face, GitHub
        path = path.replace("/src/branch/", "/raw/branch/")  # Codeberg

        # Use a context manager so the connection is always closed, and a
        # timeout so that network issues cannot hang the process indefinitely.
        with urlopen(path, timeout=timeout) as response:
            json_str = response.read().decode("utf-8")
    else:
        # The path is (assumed to be) a local file system path.
        json_str = Path(path).read_text(encoding="utf-8")

    return json.loads(json_str)
Loading