Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga authored Oct 22, 2024
1 parent d155b70 commit 74a79cc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/llamafactory/train/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys():
if any(key in name for key in diff_keys):
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is False
else:
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is True


def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
Expand Down

0 comments on commit 74a79cc

Please sign in to comment.