Skip to content

Commit

Permalink
Merge pull request #3385 from chiamp:early_stopping
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577283866
  • Loading branch information
Flax Authors committed Oct 27, 2023
2 parents 9064108 + b620117 commit 36bbb41
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 19 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ vNext
-
-
-
- Re-factored `MultiHeadDotProductAttention`'s call method signatur, by adding
- Re-factored `MultiHeadDotProductAttention`'s call method signature, by adding
`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `deterministic`
to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389).
-
-
-
- Added `has_improved` field to EarlyStopping and changed the return signature of
`EarlyStopping.update` from returning a tuple to returning just the updated class.
See more details in [#3385](https://github.com/google/flax/pull/3385)
-
-
- Use new typed PRNG keys throughout flax: this essentially involved changing
Expand All @@ -27,7 +29,7 @@ to keyword arguments. See more details in [#3389](https://github.com/google/flax
-
-
-
-
- NOTE: Remember to bump version number to 0.8.0

0.7.3
-----
Expand Down
19 changes: 11 additions & 8 deletions flax/training/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class EarlyStopping(struct.PyTreeNode):
rng, input_rng = jax.random.split(rng)
optimizer, train_metrics = train_epoch(
optimizer, train_ds, config.batch_size, epoch, input_rng)
_, early_stop = early_stop.update(train_metrics['loss'])
early_stop = early_stop.update(train_metrics['loss'])
if early_stop.should_stop:
print('Met early stopping criteria, breaking...')
break
Expand All @@ -43,35 +43,38 @@ class EarlyStopping(struct.PyTreeNode):
patience_count: Number of steps since last improving update.
should_stop: Whether the training loop should stop to avoid
overfitting.
has_improved: Whether the metric improved by at least
      `min_delta` in the last `.update` call.
"""

min_delta: float = 0
patience: int = 0
best_metric: float = float('inf')
patience_count: int = 0
should_stop: bool = False
has_improved: bool = False

def reset(self):
return self.replace(
best_metric=float('inf'), patience_count=0, should_stop=False
best_metric=float('inf'), patience_count=0, should_stop=False, has_improved=False
)

def update(self, metric):
"""Update the state based on metric.
Returns:
A pair (has_improved, early_stop), where `has_improved` is True when there
was an improvement greater than `min_delta` from the previous
`best_metric` and `early_stop` is the updated `EarlyStop` object.
The updated EarlyStopping class. The `.has_improved` attribute is True
when there was an improvement greater than `min_delta` from the previous
`best_metric`.
"""

if (
math.isinf(self.best_metric)
or self.best_metric - metric > self.min_delta
):
return True, self.replace(best_metric=metric, patience_count=0)
return self.replace(best_metric=metric, patience_count=0, has_improved=True)
else:
should_stop = self.patience_count >= self.patience or self.should_stop
return False, self.replace(
patience_count=self.patience_count + 1, should_stop=should_stop
return self.replace(
patience_count=self.patience_count + 1, should_stop=should_stop, has_improved=False
)
16 changes: 8 additions & 8 deletions tests/early_stopping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def test_update(self):
improve_steps = 0
for step in range(10):
metric = 1.0
did_improve, es = es.update(metric)
if not did_improve:
es = es.update(metric)
if not es.has_improved:
improve_steps += 1
if es.should_stop:
break
Expand All @@ -48,15 +48,15 @@ def test_patience(self):
patient_es = early_stopping.EarlyStopping(min_delta=0, patience=6)
for step in range(10):
metric = 1.0
did_improve, es = es.update(metric)
es = es.update(metric)
if es.should_stop:
break

self.assertEqual(step, 1)

for patient_step in range(10):
metric = 1.0
did_improve, patient_es = patient_es.update(metric)
patient_es = patient_es.update(metric)
if patient_es.should_stop:
break

Expand All @@ -69,7 +69,7 @@ def test_delta(self):
metric = 1.0
for step in range(100):
metric -= 1e-4
did_improve, es = es.update(metric)
es = es.update(metric)
if es.should_stop:
break

Expand All @@ -78,7 +78,7 @@ def test_delta(self):
metric = 1.0
for step in range(100):
metric -= 1e-4
did_improve, delta_es = delta_es.update(metric)
delta_es = delta_es.update(metric)
if delta_es.should_stop:
break

Expand All @@ -99,8 +99,8 @@ def test_delta(self):
improvement_steps = 0
for step in range(10):
metric = metrics[step]
did_improve, delta_patient_es = delta_patient_es.update(metric)
if did_improve:
delta_patient_es = delta_patient_es.update(metric)
if delta_patient_es.has_improved:
improvement_steps += 1
if delta_patient_es.should_stop:
break
Expand Down

0 comments on commit 36bbb41

Please sign in to comment.