feat: add configurable residual processing to reduce peak VRAM usage #239
Changes from 5 commits
f902fd5
4309c38
30685de
89d45e1
67bf79b
a7c8f09
ba17216
4a37686
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -445,13 +445,33 @@ def run(): | |
|
|
||
| print() | ||
| print("Calculating per-layer refusal directions...") | ||
| print("* Obtaining residuals for good prompts...") | ||
| good_residuals = model.get_residuals_batched(good_prompts) | ||
| print("* Obtaining residuals for bad prompts...") | ||
| bad_residuals = model.get_residuals_batched(bad_prompts) | ||
|
|
||
| good_means = good_residuals.mean(dim=0) | ||
| bad_means = bad_residuals.mean(dim=0) | ||
| needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals | ||
|
|
||
| good_residuals = None | ||
| bad_residuals = None | ||
|
|
||
| if needs_full_residuals: | ||
| print("* Obtaining residuals for good prompts...") | ||
| good_residuals = model.get_residuals_batched(good_prompts) | ||
| print("* Obtaining residuals for bad prompts...") | ||
| bad_residuals = model.get_residuals_batched(bad_prompts) | ||
|
|
||
| good_means = good_residuals.mean(dim=0) | ||
| bad_means = bad_residuals.mean(dim=0) | ||
|
|
||
| analyzer = Analyzer(settings, model, good_residuals, bad_residuals) | ||
|
|
||
| if settings.print_residual_geometry: | ||
| analyzer.print_residual_geometry() | ||
|
|
||
| if settings.plot_residuals: | ||
| analyzer.plot_residuals() | ||
| else: | ||
| print("* Obtaining residual mean for good prompts...") | ||
| good_means = model.get_residuals_mean(good_prompts) | ||
| print("* Obtaining residual mean for bad prompts...") | ||
| bad_means = model.get_residuals_mean(bad_prompts) | ||
|
|
||
| refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1) | ||
|
|
||
|
|
@@ -466,14 +486,6 @@ def run(): | |
| ) | ||
| refusal_directions = F.normalize(refusal_directions, p=2, dim=1) | ||
|
|
||
| analyzer = Analyzer(settings, model, good_residuals, bad_residuals) | ||
|
|
||
| if settings.print_residual_geometry: | ||
| analyzer.print_residual_geometry() | ||
|
|
||
| if settings.plot_residuals: | ||
| analyzer.plot_residuals() | ||
|
|
||
| # We don't need the residuals after computing refusal directions. | ||
| del good_residuals, bad_residuals, analyzer | ||
| empty_cache() | ||
|
Owner
These two lines are supposed to solve the problem you are describing. Why do you think they don't work?
Contributor
Author
Those lines handle post-analysis cleanup, but they do not reduce peak memory during accumulation. The peak occurs earlier in |
||
|
|
||
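To make the peak-memory point in the thread above concrete, here is a rough back-of-the-envelope sketch comparing the two code paths in `run()`. It assumes residuals have shape (prompt, layer, component), which is consistent with the `mean(dim=0)` calls in the diff, and that they are stored as float32. All sizes below are hypothetical and chosen only for illustration; none of them come from the PR.

```python
def tensor_bytes(shape: tuple[int, ...], bytes_per_element: int) -> int:
    """Return the storage size in bytes of a dense tensor with the given shape."""
    total = bytes_per_element
    for dim in shape:
        total *= dim
    return total


# Hypothetical sizes, for illustration only.
NUM_PROMPTS = 800    # prompts in one set (good or bad)
NUM_LAYERS = 33      # e.g. embedding output plus 32 decoder layers
HIDDEN_SIZE = 4096   # width of the residual stream
BATCH_SIZE = 32

# Full path: every prompt's residuals stay alive until the mean is taken.
# Assumes float32 storage (4 bytes per element).
full_bytes = tensor_bytes((NUM_PROMPTS, NUM_LAYERS, HIDDEN_SIZE), 4)

# Streaming path: only one batch and one float64 running sum are alive at a time.
batch_bytes = tensor_bytes((BATCH_SIZE, NUM_LAYERS, HIDDEN_SIZE), 4)
running_sum_bytes = tensor_bytes((NUM_LAYERS, HIDDEN_SIZE), 8)
streaming_bytes = batch_bytes + running_sum_bytes

print(f"full residuals per prompt set: {full_bytes / 2**20:.1f} MiB")
print(f"streaming working set:         {streaming_bytes / 2**20:.1f} MiB")
```

Under these assumptions the full path holds roughly 412.5 MiB per prompt set (and both sets are alive at once), while the streaming path needs about 17.5 MiB. Freeing the tensors afterwards with `del` and `empty_cache()` releases that memory but cannot lower a peak that has already been reached.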
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -632,6 +632,9 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor: | |
| max_new_tokens=1, | ||
| output_hidden_states=True, | ||
| return_dict_in_generate=True, | ||
| # KV cache is unnecessary here because we only need the hidden states | ||
| # for the first generated token. | ||
| use_cache=False, | ||
|
Owner
This applies to the logprobs also.
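For a sense of scale, the following sketch estimates how much memory a key/value cache would occupy for a single batch. With `max_new_tokens=1` the only forward pass is the one over the prompt, so cached keys and values are never read by a later decoding step. The formula assumes a plain dense cache (no quantization, no sliding window), and the model dimensions are hypothetical rather than taken from the PR.

```python
def kv_cache_bytes(
    num_layers: int,
    batch_size: int,
    seq_len: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_element: int,
) -> int:
    """Estimate the size of a dense KV cache for a decoder-only transformer.

    Each layer stores two tensors (keys and values), each holding
    batch_size * num_kv_heads * seq_len * head_dim elements.
    """
    per_layer = 2 * batch_size * num_kv_heads * seq_len * head_dim
    return num_layers * per_layer * bytes_per_element


# Hypothetical configuration: 32 layers, 8 KV heads of dimension 128,
# a batch of 32 prompts padded to 256 tokens, 16-bit cache entries.
cache_bytes = kv_cache_bytes(
    num_layers=32,
    batch_size=32,
    seq_len=256,
    num_kv_heads=8,
    head_dim=128,
    bytes_per_element=2,
)

print(f"KV cache for one batch: {cache_bytes / 2**30:.2f} GiB")
```

With these numbers the cache comes to exactly 1 GiB per batch. The same reasoning would apply to `get_logprobs` if it also generates a single token per prompt.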
||
| ) | ||
|
|
||
| # This cast is valid because GenerateDecoderOnlyOutput is the return type | ||
|
|
@@ -665,7 +668,11 @@ def get_residuals(self, prompts: list[Prompt]) -> Tensor: | |
| dim=2, | ||
| keepdim=True, | ||
| ) | ||
| return torch.clamp(residuals, -thresholds, thresholds) | ||
| residuals = torch.clamp(residuals, -thresholds, thresholds) | ||
|
|
||
| if self.settings.offload_outputs_to_cpu: | ||
| residuals = residuals.cpu() | ||
| empty_cache() | ||
|
|
||
| return residuals | ||
|
|
||
|
|
@@ -677,6 +684,29 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor: | |
|
|
||
| return torch.cat(residuals, dim=0) | ||
|
|
||
| def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: | ||
| running_sum = None | ||
| total_count = 0 | ||
|
|
||
| for batch in batchify(prompts, self.settings.batch_size): | ||
| batch_residuals = self.get_residuals(batch) | ||
|
|
||
| # Accumulate in high precision on CPU to reduce peak VRAM usage. | ||
| batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu() | ||
|
|
||
| if running_sum is None: | ||
| running_sum = batch_sum | ||
| else: | ||
| running_sum += batch_sum | ||
|
|
||
| total_count += batch_residuals.shape[0] | ||
|
|
||
| assert running_sum is not None, ( | ||
|
|
||
| "No prompts were provided for residual averaging." | ||
| ) | ||
|
|
||
| return (running_sum / total_count).to(torch.float32) | ||
|
|
||
| # We work with logprobs rather than probabilities for numerical stability | ||
| # when computing the KL divergence. | ||
| def get_logprobs(self, prompts: list[Prompt]) -> Tensor: | ||
|
|
@@ -698,7 +728,15 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor: | |
| logits = cast(tuple[FloatTensor], outputs.scores)[0] | ||
|
|
||
| # The returned tensor has shape (prompt, token). | ||
| return F.log_softmax(logits, dim=-1) | ||
| logprobs = F.log_softmax(logits, dim=-1) | ||
|
|
||
| del outputs | ||
|
|
||
| if self.settings.offload_outputs_to_cpu: | ||
| logprobs = logprobs.cpu() | ||
|
Owner
Does it really make sense to offload logprobs? Typical vocabulary sizes are around 250k today, so even at 32 bits per logprob, that's just 8 Megabytes. Which is essentially a rounding error even on very small GPUs.
Contributor
Author
Individually small, but across batches and repeated evaluations they can add up and contribute to allocator pressure. Offloading keeps VRAM usage more predictable during longer runs.
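A quick way to check the sizes discussed in this thread is to compute them directly. The sketch below assumes the (prompt, token) shape stated in the code comment for `get_logprobs` and uses the 250k vocabulary and 32-bit figures from the comment above. The helper name and the batch size of 8 are illustrative assumptions.

```python
def logprobs_bytes(num_prompts: int, vocab_size: int, bits_per_value: int = 32) -> int:
    """Return the size in bytes of a dense (prompt, token) logprob tensor."""
    return num_prompts * vocab_size * bits_per_value // 8


VOCAB_SIZE = 250_000  # figure quoted in the review comment

per_prompt = logprobs_bytes(1, VOCAB_SIZE)
batch_of_8 = logprobs_bytes(8, VOCAB_SIZE)  # batch size is a hypothetical example

print(f"one prompt:      {per_prompt / 1e6:.0f} MB")
print(f"batch of eight:  {batch_of_8 / 1e6:.0f} MB")
```

At 32 bits per value, one prompt row is 8 megabits, or about 1 MB, and a batch of eight rows is about 8 MB. Either way, the tensor is small compared with the residual tensors handled elsewhere in this PR.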
||
| empty_cache() | ||
|
|
||
| return logprobs | ||
|
|
||
| def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor: | ||
| logprobs = [] | ||
|
|
||
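The running-sum approach used by `get_residuals_mean` earlier in this file can be checked in isolation. The sketch below is a standalone approximation, not the PR's code: `streaming_mean` is a hypothetical free function that takes ready-made batches instead of calling the model, and it raises `ValueError` where the PR uses an `assert`. It compares the streamed result against the mean of the concatenated tensor on small random data.

```python
import torch


def streaming_mean(batches: list[torch.Tensor]) -> torch.Tensor:
    """Average over dim 0 across batches without concatenating them."""
    running_sum = None
    total_count = 0

    for batch in batches:
        # Sum in float64 to limit rounding error, then move the small result to CPU.
        batch_sum = batch.sum(dim=0, dtype=torch.float64).cpu()
        running_sum = batch_sum if running_sum is None else running_sum + batch_sum
        total_count += batch.shape[0]

    if running_sum is None:
        raise ValueError("No batches were provided for averaging.")

    return (running_sum / total_count).to(torch.float32)


torch.manual_seed(0)

# Toy residuals with shape (prompt, layer, component).
residuals = torch.randn(10, 3, 4)

# Batches of size 4, 4 and 2: the last batch is deliberately smaller.
batches = list(torch.split(residuals, 4, dim=0))

streamed = streaming_mean(batches)
full = torch.cat(batches, dim=0).mean(dim=0)

print(torch.allclose(streamed, full, atol=1e-6))
```

Accumulating sums and a total count, rather than averaging per-batch means, keeps the result correct when the final batch is smaller than the others.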