Skip to content

Commit

Permalink
reduce atol check for geneformer
Browse files Browse the repository at this point in the history
  • Loading branch information
jstjohn committed Nov 12, 2024
1 parent 2f05f9c commit dcc75ec
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class TestGeneformerStopAndGo(stop_and_go.StopAndGoHarness):
limit_val_batches: int = 2
lr: float = 1e-4
precision: Literal["16-mixed", "bf16-mixed", "32"] = MODEL_PRECISION
train_output_atol: float = 2e-2

@override
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class StopAndGoHarness(ABC):
limit_val_batches: int
lr: float = 1e-4
precision: Literal["16-mixed", "bf16-mixed", "32"]
train_output_atol: float = 1e-3
other_output_atol: float = 1e-4

# class variables that will be setup in setUpClass
tempdir: tempfile.TemporaryDirectory
Expand Down Expand Up @@ -336,9 +338,9 @@ def test_stop_and_go_consistency(self, callback_type):
assert interrupted_callback.data, f"No data found for {callback_type}"

if callback_type == testing_callbacks.TrainOutputCallback:
atol = 1e-3
atol = self.train_output_atol
else:
atol = 1e-4
atol = self.other_output_atol

recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data, atol=atol)

Expand Down

0 comments on commit dcc75ec

Please sign in to comment.