Skip to content

Commit f66e53d

Browse files
authored
fix: DPLSTM layers for FlatModel (#189)
1 parent 9303763 commit f66e53d

File tree

5 files changed

+26
-12
lines changed

5 files changed

+26
-12
lines changed

mostlyai/engine/_tabular/argn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,7 @@ def __init__(
871871
model_size: ModelSizeOrUnits,
872872
column_order: list[str] | None,
873873
device: torch.device,
874+
with_dp: bool = False,
874875
):
875876
super().__init__()
876877

@@ -892,6 +893,7 @@ def __init__(
892893
ctx_cardinalities=self.ctx_cardinalities,
893894
ctxseq_len_median=self.ctxseq_len_median,
894895
device=device,
896+
with_dp=with_dp,
895897
)
896898

897899
# sub column embeddings

mostlyai/engine/_tabular/common.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,18 @@
2020

2121
_LOG = logging.getLogger(__name__)
2222

23+
DPLSTM_SUFFIXES: tuple = ("ih.weight", "ih.bias", "hh.weight", "hh.bias")
2324

24-
def load_model_weights(model: torch.nn.Module, path: Path, device: torch.device):
25-
try:
26-
t00 = time.time()
27-
model.load_state_dict(torch.load(f=path, map_location=device, weights_only=True))
28-
_LOG.info(f"loaded model weights in {time.time() - t00:.2f}s")
29-
except Exception as e:
30-
_LOG.warning(f"failed to load model weights: {e}")
25+
26+
def load_model_weights(model: torch.nn.Module, path: Path, device: torch.device) -> None:
27+
t0 = time.time()
28+
incompatible_keys = model.load_state_dict(torch.load(f=path, map_location=device, weights_only=True), strict=False)
29+
missing_keys = incompatible_keys.missing_keys
30+
unexpected_keys = incompatible_keys.unexpected_keys
31+
# for DP-trained models, we expect extra keys from the DPLSTM layers (which is fine to ignore because we use standard LSTM layers during generation)
32+
# but if there're any other missing or unexpected keys, an error should be raised
33+
if len(missing_keys) > 0 or any(not k.endswith(DPLSTM_SUFFIXES) for k in unexpected_keys):
34+
raise RuntimeError(
35+
f"failed to load model weights due to incompatibility: {missing_keys = }, {unexpected_keys = }"
36+
)
37+
_LOG.info(f"loaded model weights in {time.time() - t0:.2f}s")

mostlyai/engine/_tabular/generation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -879,11 +879,14 @@ def generate(
879879
no_of_model_params = get_no_of_model_parameters(model)
880880
_LOG.info(f"{no_of_model_params=}")
881881

882-
load_model_weights(
883-
model=model,
884-
path=workspace.model_tabular_weights_path,
885-
device=device,
886-
)
882+
if workspace.model_tabular_weights_path.exists():
883+
load_model_weights(
884+
model=model,
885+
path=workspace.model_tabular_weights_path,
886+
device=device,
887+
)
888+
else:
889+
_LOG.warning("model weights not found; generating data with an untrained model")
887890

888891
model.to(device)
889892
model.eval()

mostlyai/engine/_tabular/training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ def train(
462462
model_size=model_size,
463463
column_order=trn_column_order,
464464
device=device,
465+
with_dp=with_dp,
465466
)
466467
_LOG.info(f"model class: {argn.__class__.__name__}")
467468

tests/end_to_end/test_tabular_sequential.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@ def test_training_strategy(self, workspace_before_training, differential_privacy
590590
# it's actually a fresh training, so the progress will look different
591591
with pytest.raises(AssertionError):
592592
pd.testing.assert_frame_equal(progress_resume.iloc[:2], progress_resume_without_checkpoint.iloc[:2])
593+
generate(workspace_dir=workspace_before_training)
593594

594595

595596
def test_seed_generation(tmp_path):

0 commit comments

Comments
 (0)