From f902fd5c86278862eba7a9065ba12cdf199f413b Mon Sep 17 00:00:00 2001 From: magiccodingman Date: Fri, 27 Mar 2026 11:57:04 -0400 Subject: [PATCH 1/8] refactor residual memory optimizations --- config.default.toml | 6 ++++++ src/heretic/config.py | 8 ++++++++ src/heretic/main.py | 40 +++++++++++++++++++++++++----------- src/heretic/model.py | 47 +++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 87 insertions(+), 14 deletions(-) diff --git a/config.default.toml b/config.default.toml index abfa0fc7..5b0688ff 100644 --- a/config.default.toml +++ b/config.default.toml @@ -133,6 +133,12 @@ refusal_markers = [ # System prompt to use when prompting the model. system_prompt = "You are a helpful assistant." +# Move intermediate analysis tensors (such as residuals and logprobs) +# to CPU memory as soon as possible to reduce peak VRAM usage. +# This lowers peak VRAM usage during residual analysis and evaluation, +# but may slightly reduce performance due to host/device transfers. +offload_outputs_to_cpu = true + # Dataset of prompts that tend to not result in refusals (used for calculating refusal directions). [good_prompts] dataset = "mlabonne/harmless_alpaca" diff --git a/src/heretic/config.py b/src/heretic/config.py index 8b70499b..21358a48 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -132,6 +132,14 @@ class Settings(BaseSettings): description="Number of input sequences to process in parallel (0 = auto).", ) + offload_outputs_to_cpu: bool = Field( + default=False, + description=( + "Whether to move intermediate analysis tensors (such as residuals and logprobs) " + "to CPU memory as soon as possible to reduce peak VRAM usage." + ), + ) + max_batch_size: int = Field( default=128, description="Maximum batch size to try when automatically determining the optimal batch size.", diff --git a/src/heretic/main.py b/src/heretic/main.py index 37233817..41ef5cb6 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -445,13 +445,30 @@ def run(): print() print("Calculating per-layer refusal directions...") - print("* Obtaining residuals for good prompts...") - good_residuals = model.get_residuals_batched(good_prompts) - print("* Obtaining residuals for bad prompts...") - bad_residuals = model.get_residuals_batched(bad_prompts) - good_means = good_residuals.mean(dim=0) - bad_means = bad_residuals.mean(dim=0) + needs_full_residuals = ( + settings.print_residual_geometry or settings.plot_residuals + ) + + good_residuals = None + bad_residuals = None + analyzer = None + + if needs_full_residuals: + print("* Obtaining residuals for good prompts...") + good_residuals = model.get_residuals_batched(good_prompts) + print("* Obtaining residuals for bad prompts...") + bad_residuals = model.get_residuals_batched(bad_prompts) + + good_means = good_residuals.mean(dim=0) + bad_means = bad_residuals.mean(dim=0) + + analyzer = Analyzer(settings, model, good_residuals, bad_residuals) + else: + print("* Obtaining residual mean for good prompts...") + good_means = model.get_residuals_mean(good_prompts) + print("* Obtaining residual mean for bad prompts...") + bad_means = model.get_residuals_mean(bad_prompts) refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1) @@ -466,13 +483,12 @@ def run(): ) refusal_directions = F.normalize(refusal_directions, p=2, dim=1) - analyzer = Analyzer(settings, model, good_residuals, bad_residuals) - - if settings.print_residual_geometry: - analyzer.print_residual_geometry() + if analyzer is not None: + if settings.print_residual_geometry: + analyzer.print_residual_geometry() - if settings.plot_residuals: - analyzer.plot_residuals() + if settings.plot_residuals: + analyzer.plot_residuals() # We don't need the residuals after computing refusal directions. del good_residuals, bad_residuals, analyzer diff --git a/src/heretic/model.py b/src/heretic/model.py index c2bda929..19b49f13 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -632,6 +632,9 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor: max_new_tokens=1, output_hidden_states=True, return_dict_in_generate=True, + # KV cache is unnecessary here because we only need the hidden states + # for the first generated token. + use_cache=False, ) # This cast is valid because GenerateDecoderOnlyOutput is the return type @@ -655,6 +658,9 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor: # problems during calculations involving residual vectors. residuals = residuals.to(torch.float32) + del hidden_states + del outputs + if 0 <= self.settings.winsorization_quantile < 1: # Apply symmetric winsorization to each layer of the per-prompt residuals. abs_residuals = torch.abs(residuals) @@ -665,7 +671,11 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor: dim=2, keepdim=True, ) - return torch.clamp(residuals, -thresholds, thresholds) + residuals = torch.clamp(residuals, -thresholds, thresholds) + + if self.settings.offload_outputs_to_cpu: + residuals = residuals.cpu() + empty_cache() return residuals @@ -676,6 +686,30 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor: residuals.append(self.get_residuals(batch)) return torch.cat(residuals, dim=0) + + def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: + running_sum = None + total_count = 0 + + for batch in batchify(prompts, self.settings.batch_size): + batch_residuals = self.get_residuals(batch) + + # Accumulate in high precision on CPU to reduce peak VRAM usage. + batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu() + + if running_sum is None: + running_sum = batch_sum + else: + running_sum += batch_sum + + total_count += batch_residuals.shape[0] + + del batch_residuals + del batch_sum + + assert running_sum is not None, "No prompts were provided for residual averaging." + + return (running_sum / total_count).to(torch.float32) # We work with logprobs rather than probabilities for numerical stability # when computing the KL divergence. @@ -698,7 +732,16 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor: logits = cast(tuple[FloatTensor], outputs.scores)[0] # The returned tensor has shape (prompt, token). - return F.log_softmax(logits, dim=-1) + logprobs = F.log_softmax(logits, dim=-1) + + del logits + del outputs + + if self.settings.offload_outputs_to_cpu: + logprobs = logprobs.cpu() + empty_cache() + + return logprobs def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor: logprobs = [] From 4309c3879fd6b7960ce8101dc831191cc201b655 Mon Sep 17 00:00:00 2001 From: magiccodingman Date: Fri, 27 Mar 2026 12:04:15 -0400 Subject: [PATCH 2/8] formatting --- src/heretic/main.py | 4 +--- src/heretic/model.py | 6 ++++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index 41ef5cb6..261ae454 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -446,9 +446,7 @@ def run(): print() print("Calculating per-layer refusal directions...") - needs_full_residuals = ( - settings.print_residual_geometry or settings.plot_residuals - ) + needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals good_residuals = None bad_residuals = None diff --git a/src/heretic/model.py b/src/heretic/model.py index 19b49f13..777ab9eb 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -686,7 +686,7 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor: residuals.append(self.get_residuals(batch)) return torch.cat(residuals, dim=0) - + def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: running_sum = None total_count = 0 @@ -707,7 +707,9 @@ def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: del batch_residuals del batch_sum - assert running_sum is not None, "No prompts were provided for residual averaging." + assert running_sum is not None, ( + "No prompts were provided for residual averaging." + ) return (running_sum / total_count).to(torch.float32) From 30685de233ea1eb1fa333cec63ace435496521d1 Mon Sep 17 00:00:00 2001 From: magiccodingman Date: Mon, 13 Apr 2026 15:21:16 -0400 Subject: [PATCH 3/8] Fixed config.py positioning and default --- src/heretic/config.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/heretic/config.py b/src/heretic/config.py index 21358a48..7fe7bb45 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -132,14 +132,6 @@ class Settings(BaseSettings): description="Number of input sequences to process in parallel (0 = auto).", ) - offload_outputs_to_cpu: bool = Field( - default=False, - description=( - "Whether to move intermediate analysis tensors (such as residuals and logprobs) " - "to CPU memory as soon as possible to reduce peak VRAM usage." - ), - ) - max_batch_size: int = Field( default=128, description="Maximum batch size to try when automatically determining the optimal batch size.", @@ -354,6 +346,14 @@ class Settings(BaseSettings): description="System prompt to use when prompting the model.", ) + offload_outputs_to_cpu: bool = Field( + default=True, + description=( + "Whether to move intermediate analysis tensors (such as residuals and logprobs) " + "to CPU memory as soon as possible to reduce peak VRAM usage." + ), + ) + good_prompts: DatasetSpecification = Field( default=DatasetSpecification( dataset="mlabonne/harmless_alpaca", From 89d45e159cf8791bbff972ecd4bb9f3007e27211 Mon Sep 17 00:00:00 2001 From: magiccodingman Date: Mon, 13 Apr 2026 15:23:57 -0400 Subject: [PATCH 4/8] fixed analyzier declaration in main.py --- src/heretic/main.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index 261ae454..3b08f647 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -450,7 +450,6 @@ def run(): good_residuals = None bad_residuals = None - analyzer = None if needs_full_residuals: print("* Obtaining residuals for good prompts...") @@ -462,6 +461,12 @@ def run(): bad_means = bad_residuals.mean(dim=0) analyzer = Analyzer(settings, model, good_residuals, bad_residuals) + + if settings.print_residual_geometry: + analyzer.print_residual_geometry() + + if settings.plot_residuals: + analyzer.plot_residuals() else: print("* Obtaining residual mean for good prompts...") good_means = model.get_residuals_mean(good_prompts) @@ -481,13 +486,6 @@ def run(): ) refusal_directions = F.normalize(refusal_directions, p=2, dim=1) - if analyzer is not None: - if settings.print_residual_geometry: - analyzer.print_residual_geometry() - - if settings.plot_residuals: - analyzer.plot_residuals() - # We don't need the residuals after computing refusal directions. del good_residuals, bad_residuals, analyzer empty_cache() From 67bf79b0d6a9328b65043deb901a4feac5d746d7 Mon Sep 17 00:00:00 2001 From: magiccodingman Date: Tue, 14 Apr 2026 16:45:51 -0400 Subject: [PATCH 5/8] removing del statements --- src/heretic/model.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index 777ab9eb..bf710ad8 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -658,9 +658,6 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor: # problems during calculations involving residual vectors. residuals = residuals.to(torch.float32) - del hidden_states - del outputs - if 0 <= self.settings.winsorization_quantile < 1: # Apply symmetric winsorization to each layer of the per-prompt residuals. abs_residuals = torch.abs(residuals) @@ -704,9 +701,6 @@ def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: total_count += batch_residuals.shape[0] - del batch_residuals - del batch_sum - assert running_sum is not None, ( "No prompts were provided for residual averaging." ) @@ -735,8 +729,7 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor: # The returned tensor has shape (prompt, token). logprobs = F.log_softmax(logits, dim=-1) - - del logits + del outputs if self.settings.offload_outputs_to_cpu: From a7c8f09ada611261c37bcb034b0ac3463f2370bc Mon Sep 17 00:00:00 2001 From: magiccodingman Date: Thu, 16 Apr 2026 18:36:46 -0400 Subject: [PATCH 6/8] ruff --- src/heretic/model.py | 2 +- uv.lock | 17 ++++------------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index bf710ad8..c488b36b 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -729,7 +729,7 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor: # The returned tensor has shape (prompt, token). logprobs = F.log_softmax(logits, dim=-1) - + del outputs if self.settings.offload_outputs_to_cpu: diff --git a/uv.lock b/uv.lock index 09cf60ed..6c4c861c 100644 --- a/uv.lock +++ b/uv.lock @@ -876,7 +876,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/6a/33d1702184d94106d3cdd7bfb788e19723206fce152e303473ca3b946c7b/greenlet-3.3.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6f8496d434d5cb2dce025773ba5597f71f5410ae499d5dd9533e0653258cdb3d", size = 273658, upload-time = "2025-12-04T14:23:37.494Z" }, { url = "https://files.pythonhosted.org/packages/d6/b7/2b5805bbf1907c26e434f4e448cd8b696a0b71725204fa21a211ff0c04a7/greenlet-3.3.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b96dc7eef78fd404e022e165ec55327f935b9b52ff355b067eb4a0267fc1cffb", size = 574810, upload-time = "2025-12-04T14:50:04.154Z" }, { url = "https://files.pythonhosted.org/packages/94/38/343242ec12eddf3d8458c73f555c084359883d4ddc674240d9e61ec51fd6/greenlet-3.3.0-cp310-cp310-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:73631cd5cccbcfe63e3f9492aaa664d278fda0ce5c3d43aeda8e77317e38efbd", size = 586248, upload-time = "2025-12-04T14:57:39.35Z" }, - { url = "https://files.pythonhosted.org/packages/f0/d0/0ae86792fb212e4384041e0ef8e7bc66f59a54912ce407d26a966ed2914d/greenlet-3.3.0-cp310-cp310-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b299a0cb979f5d7197442dccc3aee67fce53500cd88951b7e6c35575701c980b", size = 597403, upload-time = "2025-12-04T15:07:10.831Z" }, { url = "https://files.pythonhosted.org/packages/b6/a8/15d0aa26c0036a15d2659175af00954aaaa5d0d66ba538345bd88013b4d7/greenlet-3.3.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7dee147740789a4632cace364816046e43310b59ff8fb79833ab043aefa72fd5", size = 586910, upload-time = "2025-12-04T14:25:59.705Z" }, { url = "https://files.pythonhosted.org/packages/e1/9b/68d5e3b7ccaba3907e5532cf8b9bf16f9ef5056a008f195a367db0ff32db/greenlet-3.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:39b28e339fc3c348427560494e28d8a6f3561c8d2bcf7d706e1c624ed8d822b9", size = 1547206, upload-time = "2025-12-04T15:04:21.027Z" }, { url = "https://files.pythonhosted.org/packages/66/bd/e3086ccedc61e49f91e2cfb5ffad9d8d62e5dc85e512a6200f096875b60c/greenlet-3.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b3c374782c2935cc63b2a27ba8708471de4ad1abaa862ffdb1ef45a643ddbb7d", size = 1613359, upload-time = "2025-12-04T14:27:26.548Z" }, @@ -884,7 +883,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/cb/48e964c452ca2b92175a9b2dca037a553036cb053ba69e284650ce755f13/greenlet-3.3.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e29f3018580e8412d6aaf5641bb7745d38c85228dacf51a73bd4e26ddf2a6a8e", size = 274908, upload-time = "2025-12-04T14:23:26.435Z" }, { url = "https://files.pythonhosted.org/packages/28/da/38d7bff4d0277b594ec557f479d65272a893f1f2a716cad91efeb8680953/greenlet-3.3.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a687205fb22794e838f947e2194c0566d3812966b41c78709554aa883183fb62", size = 577113, upload-time = "2025-12-04T14:50:05.493Z" }, { url = "https://files.pythonhosted.org/packages/3c/f2/89c5eb0faddc3ff014f1c04467d67dee0d1d334ab81fadbf3744847f8a8a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4243050a88ba61842186cb9e63c7dfa677ec146160b0efd73b855a3d9c7fcf32", size = 590338, upload-time = "2025-12-04T14:57:41.136Z" }, - { url = "https://files.pythonhosted.org/packages/80/d7/db0a5085035d05134f8c089643da2b44cc9b80647c39e93129c5ef170d8f/greenlet-3.3.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:670d0f94cd302d81796e37299bcd04b95d62403883b24225c6b5271466612f45", size = 601098, upload-time = "2025-12-04T15:07:11.898Z" }, { url = "https://files.pythonhosted.org/packages/dc/a6/e959a127b630a58e23529972dbc868c107f9d583b5a9f878fb858c46bc1a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cb3a8ec3db4a3b0eb8a3c25436c2d49e3505821802074969db017b87bc6a948", size = 590206, upload-time = "2025-12-04T14:26:01.254Z" }, { url = "https://files.pythonhosted.org/packages/48/60/29035719feb91798693023608447283b266b12efc576ed013dd9442364bb/greenlet-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2de5a0b09eab81fc6a382791b995b1ccf2b172a9fec934747a7a23d2ff291794", size = 1550668, upload-time = "2025-12-04T15:04:22.439Z" }, { url = "https://files.pythonhosted.org/packages/0a/5f/783a23754b691bfa86bd72c3033aa107490deac9b2ef190837b860996c9f/greenlet-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4449a736606bd30f27f8e1ff4678ee193bc47f6ca810d705981cfffd6ce0d8c5", size = 1615483, upload-time = "2025-12-04T14:27:28.083Z" }, @@ -892,7 +890,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, { url = "https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, - { url = "https://files.pythonhosted.org/packages/77/cb/43692bcd5f7a0da6ec0ec6d58ee7cddb606d055ce94a62ac9b1aa481e969/greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7", size = 622297, upload-time = "2025-12-04T15:07:13.552Z" }, { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, { url = "https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, @@ -900,7 +897,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = "https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = "2025-12-04T14:57:43.968Z" }, - { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -908,7 +904,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/7c/f0a6d0ede2c7bf092d00bc83ad5bafb7e6ec9b4aab2fbdfa6f134dc73327/greenlet-3.3.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:60c2ef0f578afb3c8d92ea07ad327f9a062547137afe91f38408f08aacab667f", size = 275671, upload-time = "2025-12-04T14:23:05.267Z" }, { url = "https://files.pythonhosted.org/packages/44/06/dac639ae1a50f5969d82d2e3dd9767d30d6dbdbab0e1a54010c8fe90263c/greenlet-3.3.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a5d554d0712ba1de0a6c94c640f7aeba3f85b3a6e1f2899c11c2c0428da9365", size = 646360, upload-time = "2025-12-04T14:50:10.026Z" }, { url = "https://files.pythonhosted.org/packages/e0/94/0fb76fe6c5369fba9bf98529ada6f4c3a1adf19e406a47332245ef0eb357/greenlet-3.3.0-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3a898b1e9c5f7307ebbde4102908e6cbfcb9ea16284a3abe15cab996bee8b9b3", size = 658160, upload-time = "2025-12-04T14:57:45.41Z" }, - { url = "https://files.pythonhosted.org/packages/93/79/d2c70cae6e823fac36c3bbc9077962105052b7ef81db2f01ec3b9bf17e2b/greenlet-3.3.0-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:dcd2bdbd444ff340e8d6bdf54d2f206ccddbb3ccfdcd3c25bf4afaa7b8f0cf45", size = 671388, upload-time = "2025-12-04T15:07:15.789Z" }, { url = "https://files.pythonhosted.org/packages/b8/14/bab308fc2c1b5228c3224ec2bf928ce2e4d21d8046c161e44a2012b5203e/greenlet-3.3.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5773edda4dc00e173820722711d043799d3adb4f01731f40619e07ea2750b955", size = 660166, upload-time = "2025-12-04T14:26:05.099Z" }, { url = "https://files.pythonhosted.org/packages/4b/d2/91465d39164eaa0085177f61983d80ffe746c5a1860f009811d498e7259c/greenlet-3.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ac0549373982b36d5fd5d30beb8a7a33ee541ff98d2b502714a09f1169f31b55", size = 1615193, upload-time = "2025-12-04T15:04:27.041Z" }, { url = "https://files.pythonhosted.org/packages/42/1b/83d110a37044b92423084d52d5d5a3b3a73cafb51b547e6d7366ff62eff1/greenlet-3.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d198d2d977460358c3b3a4dc844f875d1adb33817f0613f663a656f463764ccc", size = 1683653, upload-time = "2025-12-04T14:27:32.366Z" }, @@ -916,7 +911,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/66/bd6317bc5932accf351fc19f177ffba53712a202f9df10587da8df257c7e/greenlet-3.3.0-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:d6ed6f85fae6cdfdb9ce04c9bf7a08d666cfcfb914e7d006f44f840b46741931", size = 282638, upload-time = "2025-12-04T14:25:20.941Z" }, { url = "https://files.pythonhosted.org/packages/30/cf/cc81cb030b40e738d6e69502ccbd0dd1bced0588e958f9e757945de24404/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d9125050fcf24554e69c4cacb086b87b3b55dc395a8b3ebe6487b045b2614388", size = 651145, upload-time = "2025-12-04T14:50:11.039Z" }, { url = "https://files.pythonhosted.org/packages/9c/ea/1020037b5ecfe95ca7df8d8549959baceb8186031da83d5ecceff8b08cd2/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:87e63ccfa13c0a0f6234ed0add552af24cc67dd886731f2261e46e241608bee3", size = 654236, upload-time = "2025-12-04T14:57:47.007Z" }, - { url = "https://files.pythonhosted.org/packages/69/cc/1e4bae2e45ca2fa55299f4e85854606a78ecc37fead20d69322f96000504/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2662433acbca297c9153a4023fe2161c8dcfdcc91f10433171cf7e7d94ba2221", size = 662506, upload-time = "2025-12-04T15:07:16.906Z" }, { url = "https://files.pythonhosted.org/packages/57/b9/f8025d71a6085c441a7eaff0fd928bbb275a6633773667023d19179fe815/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3c6e9b9c1527a78520357de498b0e709fb9e2f49c3a513afd5a249007261911b", size = 653783, upload-time = "2025-12-04T14:26:06.225Z" }, { url = "https://files.pythonhosted.org/packages/f6/c7/876a8c7a7485d5d6b5c6821201d542ef28be645aa024cfe1145b35c120c1/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:286d093f95ec98fdd92fcb955003b8a3d054b4e2cab3e2707a5039e7b50520fd", size = 1614857, upload-time = "2025-12-04T15:04:28.484Z" }, { url = "https://files.pythonhosted.org/packages/4f/dc/041be1dff9f23dac5f48a43323cd0789cb798342011c19a248d9c9335536/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9", size = 1676034, upload-time = "2025-12-04T14:27:33.531Z" }, @@ -961,8 +955,6 @@ research = [ { name = "geom-median" }, { name = "imageio" }, { name = "matplotlib" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pacmap" }, { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -983,13 +975,12 @@ requires-dist = [ { name = "hf-transfer", specifier = "~=0.1" }, { name = "huggingface-hub", specifier = "~=1.7" }, { name = "imageio", marker = "extra == 'research'", specifier = "~=2.37" }, - { name = "immutabledict", specifier = ">=4.3.1" }, + { name = "immutabledict", specifier = "~=4.3" }, { name = "kernels", specifier = "~=0.12" }, - { name = "langdetect", specifier = ">=1.0.9" }, - { name = "lm-eval", extras = ["hf"], specifier = "~=0.4.11" }, + { name = "langdetect", specifier = "~=1.0" }, + { name = "lm-eval", extras = ["hf"], specifier = "~=0.4" }, { name = "matplotlib", marker = "extra == 'research'", specifier = "~=3.10" }, - { name = "numpy", specifier = ">=2.2.6" }, - { name = "numpy", marker = "extra == 'research'", specifier = "~=2.2" }, + { name = "numpy", specifier = "~=2.2" }, { name = "optuna", specifier = "~=4.7" }, { name = "pacmap", marker = "extra == 'research'", specifier = "~=0.8" }, { name = "peft", specifier = "~=0.18" }, From ba17216a59a75014b4e7762b7c4ee3e2d1786347 Mon Sep 17 00:00:00 2001 From: magiccodingman Date: Thu, 16 Apr 2026 18:56:35 -0400 Subject: [PATCH 7/8] small updates --- src/heretic/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index c488b36b..3ebf95b0 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -685,6 +685,9 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor: return torch.cat(residuals, dim=0) def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: + if not prompts: + raise ValueError("prompts must not be empty") + running_sum = None total_count = 0 @@ -701,10 +704,6 @@ def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: total_count += batch_residuals.shape[0] - assert running_sum is not None, ( - "No prompts were provided for residual averaging." - ) - return (running_sum / total_count).to(torch.float32) # We work with logprobs rather than probabilities for numerical stability @@ -717,6 +716,7 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor: max_new_tokens=1, output_scores=True, return_dict_in_generate=True, + use_cache=False, ) # This cast is valid because GenerateDecoderOnlyOutput is the return type From 4a376862bd8c61f9cf0b2bc83d403e25cd2935ef Mon Sep 17 00:00:00 2001 From: magiccodingman Date: Fri, 17 Apr 2026 10:07:38 -0400 Subject: [PATCH 8/8] ty moveback ish --- src/heretic/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/heretic/model.py b/src/heretic/model.py index 3ebf95b0..475e34a7 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -704,6 +704,8 @@ def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: total_count += batch_residuals.shape[0] + assert running_sum is not None + return (running_sum / total_count).to(torch.float32) # We work with logprobs rather than probabilities for numerical stability