Reduce peak memory during FLUX model load #7564

Merged: RyanJDick merged 1 commit into main from ryan/lower-virtual-memory on Jan 16, 2025

Conversation

RyanJDick
Collaborator

Summary

Prior to this change, there were several cases where we initialized the weights of a FLUX model before loading its state dict (and, to make things worse, in some cases the weights were in float32). This PR fixes a handful of these cases. (I think I found all instances for the FLUX family of models.)
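
For reference, the general pattern looks roughly like this (a minimal sketch, not the code from this PR; `load_on_meta_then_assign` and `model_cls` are illustrative names, and it assumes a PyTorch version where `load_state_dict` supports `assign=True`, i.e. >= 2.1):

```python
import torch

def load_on_meta_then_assign(model_cls, state_dict):
    """Minimal sketch: build the module on the meta device (so no memory is
    allocated for randomly initialized weights), then attach the real tensors."""
    with torch.device("meta"):
        model = model_cls()  # e.g. a FLUX submodel class (hypothetical here)
    # assign=True swaps the meta parameters for the loaded tensors instead of
    # trying to copy into storage that was never allocated.
    model.load_state_dict(state_dict, assign=True)
    return model
```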

Related Issues / Discussions

QA Instructions

I tested that model loading still works and that there is no virtual memory reservation on model initialization for the following models:

  • FLUX VAE
  • Full T5 Encoder
  • Full FLUX checkpoint
  • GGUF FLUX checkpoint

Merge Plan

No special instructions.

Checklist

  • The PR has a short but descriptive title, suitable for a changelog
  • Tests added / updated (if applicable)
  • Documentation added / updated (if applicable)
  • Updated What's New copy (if doing a release after this PR)

…models to ensure that the model is initialized on the meta device prior to loading the state dict into it. This helps to keep peak memory down.
github-actions bot added the python (PRs that change python files) and backend (PRs that change backend files) labels on Jan 16, 2025
RyanJDick merged commit 0abb5ea into main on Jan 16, 2025
15 checks passed
RyanJDick deleted the ryan/lower-virtual-memory branch on January 16, 2025 at 23:47
RyanJDick added a commit that referenced this pull request Jan 17, 2025
## Summary

This PR adds a `keep_ram_copy_of_weights` config option; the default (and legacy) behavior is `true`. The tradeoffs for this setting are as follows (see the sketch after this list):
- `keep_ram_copy_of_weights: true`: faster model switching and LoRA patching.
- `keep_ram_copy_of_weights: false`: lower average RAM load (may not help significantly with peak RAM).
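
To make the tradeoff concrete, here is an illustrative sketch (not the actual cache code in this PR; `LoadedModel` and its methods are hypothetical) of what keeping a RAM copy of the weights enables:

```python
import torch

class LoadedModel:
    """Hypothetical illustration of the `keep_ram_copy_of_weights` tradeoff."""

    def __init__(self, model: torch.nn.Module, keep_ram_copy_of_weights: bool):
        self.model = model
        # Optionally retain a CPU copy of the unpatched weights.
        self.cpu_state_dict = (
            {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            if keep_ram_copy_of_weights
            else None
        )

    def restore_original_weights(self) -> None:
        """Undo in-place patches (e.g. LoRA) applied to the live model."""
        if self.cpu_state_dict is not None:
            # Fast path: copy the retained weights back in from RAM.
            self.model.load_state_dict(self.cpu_state_dict)
        else:
            # Without a RAM copy, the original weights would have to be
            # re-read from disk (slower, but lower average RAM usage).
            raise NotImplementedError("reload weights from disk")
```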

## Related Issues / Discussions

- Helps with #7563
- The Low-VRAM docs are updated to include this feature in #7566

## QA Instructions

- Test with `enable_partial_load: false` and `keep_ram_copy_of_weights: false`.
  - [x] RAM usage when model is loaded is reduced.
  - [x] Model loading / unloading works as expected.
  - [x] LoRA patching still works.
- Test with `enable_partial_load: false` and `keep_ram_copy_of_weights: true`.
  - [x] Behavior should be unchanged.
- Test with `enable_partial_load: true` and `keep_ram_copy_of_weights: false`.
  - [x] RAM usage when model is loaded is reduced.
  - [x] Model loading / unloading works as expected.
  - [x] LoRA patching still works.
- Test with `enable_partial_load: true` and `keep_ram_copy_of_weights: true`.
  - [x] Behavior should be unchanged.

- [x] Smoke test CPU-only and MPS with default configs.

## Merge Plan

- [x] Merge #7564 first and change target branch.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_