Cleanse state dict of shared pointers before save #159
Conversation
👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.
Can you please test that? Specifically, the issue I'm seeing seems to relate to loading a model which has been saved after oneshot.
Using these three scripts to test:

**reload_stories_normal.py**
```python
import os, shutil

from llmcompressor.transformers import SparseAutoModelForCausalLM

output_dir = "./my_model"
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

# base
model = SparseAutoModelForCausalLM.from_pretrained(
    "Xenova/llama2.c-stories15M", device_map="auto", torch_dtype="auto"
)

# save
model.save_pretrained(
    output_dir,
    save_compressed=True,
    safe_serialization=False,  # False := pytorch_model.bin, True := model.safetensors
)

# load normal
model = SparseAutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
print(model)
```

**reload_stories_oneshot.py**
```python
import os, shutil

from llmcompressor.core import create_session
from llmcompressor.transformers import (
    SparseAutoModelForCausalLM,
    oneshot,
)

output_dir = "./oneshot_out"
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml"
dataset = "open_platypus"
concatenate_data = False
num_calibration_samples = 64
splits = {"calibration": "train[:10%]"}

# base
model = SparseAutoModelForCausalLM.from_pretrained(
    "Xenova/llama2.c-stories15M", device_map="auto"
)

# save oneshot
with create_session():
    oneshot(
        model=model,
        dataset=dataset,
        output_dir=output_dir,
        num_calibration_samples=num_calibration_samples,
        recipe=recipe_str,
        concatenate_data=concatenate_data,
        splits=splits,
    )

# load oneshot
model = SparseAutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
print(model)
```

**reload_stories_distill.py**
```python
import os, shutil

from llmcompressor.core import create_session
from llmcompressor.transformers import (
    SparseAutoModelForCausalLM,
    oneshot,
    train,
)

output_dir = "./distill_out"
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

dataset = "open_platypus"
concatenate_data = False
splits = "train[:50%]"
max_steps = 50
num_calibration_samples = 64
recipe_str = "tests/llmcompressor/transformers/finetune/test_finetune_recipe.yaml"

# base
model = SparseAutoModelForCausalLM.from_pretrained(
    "Xenova/llama2.c-stories15M", device_map="auto"
)
distill_teacher = SparseAutoModelForCausalLM.from_pretrained(
    "Xenova/llama2.c-stories15M", device_map="auto"
)

# distill
with create_session():
    train(
        model=model,
        distill_teacher=distill_teacher,
        dataset=dataset,
        output_dir=output_dir,
        num_calibration_samples=num_calibration_samples,
        recipe=recipe_str,
        concatenate_data=concatenate_data,
        splits=splits,
        max_steps=max_steps,
    )

# load
model = SparseAutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
print(model)
```
Loading a model which has removed tensors leads to loading failures right now. I think we'd need to adapt safetensors' load_model function on the loading side in order to fully support this method.
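For context, safetensors already ships a model-level save/load pair that tolerates de-duplicated shared tensors, unlike raw `save_file`/`load_file` plus `load_state_dict`. A minimal sketch of that round trip (the `Linear` stand-in model is mine, for illustration only):

```python
# Illustrative only: safetensors' model-level helpers handle shared tensors;
# save_model drops duplicates, load_model re-ties them on the way back in.
import torch
from safetensors.torch import load_model, save_model

model = torch.nn.Linear(4, 4)  # stand-in for the real model
save_model(model, "model.safetensors")  # duplicated shared tensors are dropped
missing, unexpected = load_model(model, "model.safetensors")  # aliases re-tied
```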
I'm still investigating why load_model/save_model is not the default pathway. Perhaps metadata loss.
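One concrete instance of that concern, as I understand it: transformers tags its safetensors checkpoints with a format marker in the header, so any replacement save pathway would have to carry it through. A hedged sketch (the metadata value is my assumption about what needs preserving):

```python
# Assumption for illustration: transformers expects a {"format": "pt"} marker
# in the safetensors header; save_model accepts a metadata dict, so a custom
# pathway could pass it through rather than lose it.
from safetensors.torch import save_model

save_model(model, "model.safetensors", metadata={"format": "pt"})
```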
I think the problem with this approach is that it ignores the root cause, namely that tensors are being shared when they shouldn't be. I've opened an issue on HF transformers here.
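For reference, a minimal reproduction of the underlying constraint (tensor names here are illustrative, not from the PR): safetensors refuses to serialize two state-dict entries that alias the same storage.

```python
# Minimal repro of the shared-tensor restriction in safetensors.
import torch
from safetensors.torch import save_file

weight = torch.zeros(4)
state = {"a": weight, "b": weight}  # "b" aliases "a"'s storage
save_file(state, "out.safetensors")  # raises RuntimeError: tensors share memory
```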
SUMMARY:
Adapts code from https://github.com/huggingface/safetensors/blob/5db3b92c76ba293a0715b916c16b113c0b3551e9/bindings/python/py_src/safetensors/torch.py#L155 to cleanse the state dict of shared pointers before saving.
Also check: https://huggingface.co/docs/safetensors/en/torch_shared_tensors
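A rough sketch of the cleansing idea, loosely adapted from the linked safetensors helper; the function and variable names below are mine, not the PR's, and the real safetensors logic also verifies that aliases cover the same storage region, which this sketch skips:

```python
# Sketch of de-duplicating shared tensors before save (assumes PyTorch >= 2.0
# for untyped_storage()). Not the PR's actual implementation.
from collections import defaultdict

import torch


def cleanse_shared_pointers(state_dict: dict) -> dict:
    """Keep one name per underlying storage, dropping aliases that share it."""
    by_storage = defaultdict(list)
    for name, tensor in state_dict.items():
        # Tensors that view the same storage (e.g. tied embeddings) report
        # the same storage data pointer.
        by_storage[tensor.untyped_storage().data_ptr()].append(name)

    cleansed = dict(state_dict)
    for names in by_storage.values():
        # Deterministically keep the first name; drop the duplicate aliases.
        for alias in sorted(names)[1:]:
            del cleansed[alias]
    return cleansed
```

On the loading side, the dropped aliases then have to be re-tied, which is what `load_model` in the earlier sketch does.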
TEST PLAN:
The tests now pass on CPU and GPU.
Also fixes the shared tensors issue seen in ex_trl_distillation.