Reduce peak memory during FLUX model load #7564
Merged
Summary
Prior to this change, there were several places where we fully initialized the weights of a FLUX model before loading its state dict (and, to make matters worse, in some cases the initial weights were in float32). Those throwaway weights inflated peak memory during model load, since two full copies of the weights briefly existed at once. This PR fixes a handful of these cases. (I believe I found all instances for the FLUX family of models.)
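A minimal sketch of the general technique (not the exact code from this PR): in recent PyTorch, a model can be constructed on the `meta` device so its initial weights reserve no real memory, and the checkpoint tensors can then be attached directly with `load_state_dict(..., assign=True)`. `TinyModel` below is a hypothetical stand-in for a FLUX submodule.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for a FLUX submodule.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(1024, 1024)


# Stand-in for a checkpoint loaded from disk.
state_dict = TinyModel().state_dict()

# Naive approach: constructing the model eagerly materializes its weights
# (in float32 by default), which are then immediately overwritten by the
# checkpoint -- peak memory is roughly doubled while both copies exist.
#
# Instead, construct the model on the meta device so no memory is reserved
# for the initial weights, then attach the checkpoint tensors in place of
# the meta parameters with assign=True (PyTorch >= 2.1).
with torch.device("meta"):
    model = TinyModel()
model.load_state_dict(state_dict, assign=True)
```

With this pattern, peak memory during load is roughly the size of the checkpoint itself rather than checkpoint plus freshly initialized weights.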
Related Issues / Discussions
QA Instructions
I tested that model loading still works and that there is no virtual memory reservation on model initialization for the following models:
Merge Plan
No special instructions.
Checklist
What's New copy (if doing a release after this PR)