[Bugfix] Prepare KD Models when Saving #174

Merged · 46 commits · Oct 7, 2024
Commits
b201b03  WIP (Sep 12, 2024)
add7d32  load offload state dict (Sep 12, 2024)
78088f5  Merge branch 'kylesayrs/fix-offloaded-saving' into kylesayrs/pickle-c… (Sep 12, 2024)
cf68112  use mixin (Sep 13, 2024)
264f4ae  remove files (Sep 13, 2024)
94678de  restore session mixin (Sep 13, 2024)
4624c46  restore session mixin (Sep 13, 2024)
9b054fa  restore (Sep 13, 2024)
4bec795  restore (Sep 13, 2024)
375c074  remove unnecessary import (Sep 13, 2024)
e8aaf1e  remove breakpoint (Sep 13, 2024)
40a0425  Merge branch 'main' into kylesayrs/fix-offloaded-saving (kylesayrs, Sep 13, 2024)
64c3834  Merge branch 'main' into kylesayrs/fix-offloaded-saving (kylesayrs, Sep 13, 2024)
dbfd400  Merge branch 'main' into kylesayrs/pickle-checkpoints (kylesayrs, Sep 13, 2024)
9d413ec  capture return value, apply style (Sep 13, 2024)
e3a024e  restore (Sep 13, 2024)
0272a77  restore (Sep 13, 2024)
860cf71  restore (Sep 13, 2024)
b11a2f2  WIP (Sep 17, 2024)
c141859  implement prepare kd for saving (Sep 23, 2024)
07081b1  test reloading (Sep 23, 2024)
4929997  add reloading to tests (Sep 23, 2024)
31a73a8  Merge branch 'main' into kylesayrs/pickle-checkpoints (Sep 23, 2024)
32619d1  Merge remote-tracking branch 'origin/main' into kylesayrs/pickle-chec… (Sep 23, 2024)
27f6fcc  add wait for everyone after finalizing (Sep 23, 2024)
3adf0ea  revert to previous, save for other pr (Sep 24, 2024)
570189d  add test (Sep 24, 2024)
f52d685  remove merge duplication (Sep 24, 2024)
388abd1  prepare to fix tie_word_embeddings (Sep 24, 2024)
cc62178  add full tests (Sep 25, 2024)
0b990a7  comment out failing tests, point to next pr (Sep 25, 2024)
048f3f8  Merge branch 'main' into kylesayrs/pickle-checkpoints (kylesayrs, Sep 25, 2024)
8c76a1d  Merge branch 'main' into kylesayrs/fix-offloaded-saving (kylesayrs, Sep 25, 2024)
201d482  apply style (Sep 25, 2024)
71b16b2  Merge branch 'kylesayrs/fix-offloaded-saving' of https://github.com/v… (Sep 25, 2024)
0222b40  Merge branch 'kylesayrs/fix-offloaded-saving' into kylesayrs/pickle-c… (Sep 25, 2024)
e401736  apply quality (Sep 25, 2024)
de84652  Remove failing tests (kylesayrs, Sep 26, 2024)
00791b6  explicitly set safe_serialization (Sep 27, 2024)
4496220  Merge branch 'main' into kylesayrs/fix-offloaded-saving (kylesayrs, Sep 28, 2024)
11ddfb0  Merge remote-tracking branch 'origin' into kylesayrs/fix-offloaded-sa… (kylesayrs, Oct 1, 2024)
b4789aa  Merge branch 'kylesayrs/fix-offloaded-saving' into kylesayrs/pickle-c… (kylesayrs, Oct 1, 2024)
7a05e61  separate out gpu tests, apply style (kylesayrs, Oct 1, 2024)
cfd06ed  Merge branch 'kylesayrs/fix-offloaded-saving' into kylesayrs/pickle-c… (kylesayrs, Oct 1, 2024)
dc3d1f0  Merge branch 'main' into kylesayrs/pickle-checkpoints (kylesayrs, Oct 4, 2024)
214500e  Merge branch 'main' into kylesayrs/pickle-checkpoints (kylesayrs, Oct 7, 2024)
@@ -24,6 +24,7 @@ def __init__(
         super(KDModuleWrapper, self).__init__()
 
         self.layer = layer
+        self._save_active = False
         self._fsdp_active = fsdp_active
         self.offload_output = offload_output
         self.kd_transforms = transforms
@@ -88,16 +89,28 @@ def named_modules(
         prefix: str = "",
         remove_duplicate: bool = True,
     ):
-        # we want the full names of modules in two cases
+        # outside of saving, we want the full names of modules in two cases:
         # 1. trainer initialization, so teacher is moved to the correct device. This is
         #    caught by the kd_enabled flag, which is set when the modifier is started
         # 2. running in DataParallel (non-FSDP) mode so the replicate function can pick
         #    up the teacher.
-        if not self.kd_enabled or not self._fsdp_active:
-            return super().named_modules(
+        if self._save_active or (self.kd_enabled and self._fsdp_active):
+            return self.layer.named_modules(
                 memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
             )
 
-        return self.layer.named_modules(
+        return super().named_modules(
             memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
         )
+
+    def prepare_for_save(self):
+        """
+        Prepare model structure to be saved, specifically `self.named_modules`
+        """
+        self._save_active = True
+
+    def finish_save(self):
+        """
+        Finish saving model
+        """
+        self._save_active = False
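
(Reviewer note, not part of the diff.) The `_save_active` flag works purely by changing which module tree `named_modules()` walks. A minimal, self-contained sketch of the same pattern; `TransparentWrapper` here is a made-up toy class, not an llmcompressor API:

import torch.nn as nn

class TransparentWrapper(nn.Module):
    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer
        self._save_active = False

    def named_modules(self, memo=None, prefix="", remove_duplicate=True):
        if self._save_active:
            # while saving, report the wrapped module's own names
            return self.layer.named_modules(
                memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
            )
        # otherwise behave like a normal wrapper module
        return super().named_modules(
            memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
        )

wrapper = TransparentWrapper(nn.Sequential(nn.Linear(4, 4)))
print([name for name, _ in wrapper.named_modules()])  # ['', 'layer', 'layer.0']
wrapper._save_active = True
print([name for name, _ in wrapper.named_modules()])  # ['', '0']

With the flag set, the enumerated names no longer carry the wrapper's `layer.` prefix, which is what lets the save path see the wrapped model's original structure.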
@@ -23,6 +23,7 @@ def __init__(
         self.teacher_model = teacher_model
         self.wrappers = wrappers
         self.kd_comparison = comparison
+        self._save_active = False
         self._fsdp_active = fsdp_active
         self.kd_enabled = False
         self.register_buffer(self.KD_LAST_COMPARISON, torch.zeros(1, device="cpu"))
@@ -88,17 +89,17 @@ def named_modules(
         prefix: str = "",
         remove_duplicate: bool = True,
     ):
-        # we want the full names of modules in two cases
+        # outside of saving, we want the full names of modules in two cases:
         # 1. trainer initialization, so teacher is moved to the correct device. This is
         #    caught by the kd_enabled flag, which is set when the modifier is started
         # 2. running in DataParallel (non-FSDP) mode so the replicate function can pick
         #    up the teacher.
-        if not self.kd_enabled or not self._fsdp_active:
-            return super().named_modules(
+        if self._save_active or (self.kd_enabled and self._fsdp_active):
+            return self.student_model.named_modules(
                 memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
             )
 
-        return self.student_model.named_modules(
+        return super().named_modules(
             memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
         )
@@ -109,6 +110,24 @@ def train(self, mode: bool = True):
         self.student_model.train(mode)
         return self
 
+    def prepare_for_save(self):
+        """
+        Prepare model structure to be saved, specifically `self.named_modules`
+        """
+        self._save_active = True
+        for student_wrapper, teacher_wrapper in self.wrappers.values():
+            student_wrapper.prepare_for_save()
+            teacher_wrapper.prepare_for_save()
+
+    def finish_save(self):
+        """
+        Finish saving model
+        """
+        self._save_active = False
+        for student_wrapper, teacher_wrapper in self.wrappers.values():
+            student_wrapper.finish_save()
+            teacher_wrapper.finish_save()
+
     def __getattr__(self, name: str) -> Any:
         try:
             return super().__getattr__(name)
src/llmcompressor/transformers/finetune/session_mixin.py (23 additions, 1 deletion)
@@ -21,6 +21,9 @@
     pre_initialize_structure,
 )
 from llmcompressor.metrics import LoggerManager
+from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import (
+    KDModelWrapper,
+)
 from llmcompressor.pytorch.model_load.helpers import RECIPE_FILE_NAME, get_session_model
 from llmcompressor.pytorch.utils import ModuleSparsificationInfo
 from llmcompressor.transformers.finetune.callbacks import (
@@ -341,13 +344,25 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
         :param kwargs: keyword args to pass to super().train()
         :return: the output from super.train()
         """
 
+        # lifecycle
         checkpoint, epoch = self._calculate_checkpoint_info(kwargs)
         self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)
+
+        # do not save checkpoints as compressed
+        original_save_compressed = self.args.save_compressed
+        self.args.save_compressed = False
+
+        # train with accelerator
         self.accelerator.wait_for_everyone()
         output = super().train(*args, **kwargs)
         self.accelerator.wait_for_everyone()
-        self.finalize_session()
+
+        # restore original setting for saving final model
+        self.args.save_compressed = original_save_compressed
+
+        # lifecycle
+        self.finalize_session()
         self.accelerator.wait_for_everyone()
 
         # log model sparsity
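
(Reviewer note, not part of the diff.) The checkpoint-compression handling above is a save/flip/restore of one flag. A standalone sketch of that pattern, assuming `args.save_compressed` is the only switch the checkpoint writer reads; `Args` and `run_training` are placeholders, and the PR itself restores the flag inline rather than in a `finally` block:

from dataclasses import dataclass
from typing import Callable

@dataclass
class Args:
    save_compressed: bool = True  # stand-in for the trainer's args object

def train_with_uncompressed_checkpoints(args: Args, run_training: Callable[[], None]) -> None:
    original = args.save_compressed
    args.save_compressed = False         # intermediate checkpoints are written uncompressed
    try:
        run_training()                   # stands in for super().train(*args, **kwargs)
    finally:
        args.save_compressed = original  # the final save keeps the user's setting

args = Args()
train_with_uncompressed_checkpoints(args, lambda: None)
assert args.save_compressed is True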
@@ -430,6 +445,10 @@ def save_model(
         if output_dir is None:
             output_dir = self.args.output_dir
 
+        # knowledge distillation requires making wrappers transparent during saving
+        if isinstance(self.model, KDModelWrapper):
+            self.model.prepare_for_save()
+
         if not is_fsdp_model(self.model):
             self.model.save_pretrained(
                 output_dir,
@@ -467,6 +486,9 @@
 
         self.accelerator.wait_for_everyone()
 
+        if isinstance(self.model, KDModelWrapper):
+            self.model.finish_save()
+
     def maybe_log_model_sparsification(self):
         """
         Log info on model sparsity and quantization if possible. Only print logs on the
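
(Reviewer note, not part of the diff.) The `prepare_for_save()` / `finish_save()` pair in `save_model` above could also be packaged as a context manager so the wrappers are restored even if the save raises. A hedged sketch of that alternative, with `Dummy` standing in for a KD-wrapped model:

from contextlib import contextmanager

@contextmanager
def transparent_for_save(model):
    # call the hooks only if the model defines them (i.e. it is KD-wrapped)
    if hasattr(model, "prepare_for_save"):
        model.prepare_for_save()
    try:
        yield model
    finally:
        if hasattr(model, "finish_save"):
            model.finish_save()

class Dummy:
    def prepare_for_save(self):
        print("prepare_for_save")

    def finish_save(self):
        print("finish_save")

with transparent_for_save(Dummy()) as model:
    print("saving checkpoint for", type(model).__name__)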
@@ -62,7 +62,7 @@ def test_oneshot_then_finetune(self):
         concatenate_data = False
         output_dir = self.output / "finetune_out"
         splits = "train[:50%]"
-        max_steps = 50
+        max_steps = 25
 
         with create_session():
             train(
@@ -77,5 +77,23 @@
                 max_steps=max_steps,
             )
 
+        # test reloading checkpoint and final model
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            output_dir, device_map="auto"
+        )
+        with create_session():
+            train(
+                model=model,
+                distill_teacher=distill_teacher,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe_str,
+                concatenate_data=concatenate_data,
+                splits=splits,
+                max_steps=max_steps,
+                resume_from_checkpoint=True,  # use last checkpoint
+            )
+
     def tearDown(self):
         shutil.rmtree(self.output)