Skip to content

Commit

Permalink
Fix weights loading
Browse files Browse the repository at this point in the history
  • Loading branch information
liambai committed Jan 30, 2025
1 parent ea0a693 commit 0ac89a0
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions interprot/endpoints/sae_inference/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,16 +393,19 @@ def load_models():
logger.info(f"Loading SAE model {sae_name}")
sae_model = SparseAutoencoder(plm_dim, sae_dim).to(device)
sae_weights = os.path.join(WEIGHTS_DIR, sae_checkpoint)
# Support different checkpoint formats

try:
sae_model.load_state_dict(torch.load(sae_weights))
sae_model.load_state_dict(torch.load(sae_weights, weights_only=False))
except Exception:
sae_model.load_state_dict(
{
k.replace("sae_model.", ""): v
for k, v in torch.load(sae_weights)["state_dict"].items()
}
)
try:
checkpoint = torch.load(sae_weights, weights_only=False)
sae_model.load_state_dict(
{k.replace("sae_model.", ""): v for k, v in checkpoint["state_dict"].items()}
)
except Exception as e:
logger.error(f"Failed to load SAE weights: {str(e)}")
raise

sea_name_to_info[sae_name] = {
"model": sae_model,
"plm_layer": plm_layer,
Expand Down

0 comments on commit 0ac89a0

Please sign in to comment.