diff --git a/CHANGELOG.md b/CHANGELOG.md index 5032ab925c..3f8e5a5850 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,14 @@ vNext - - - -- Re-factored `MultiHeadDotProductAttention`'s call method signatur, by adding +- Re-factored `MultiHeadDotProductAttention`'s call method signature, by adding `inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `determistic` to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389). - - -- +- Added `has_improved` field to EarlyStopping and changed the return signature of +`EarlyStopping.update` from returning a tuple to returning just the updated class. +See more details in [#3385](https://github.com/google/flax/pull/3385) - - - Use new typed PRNG keys throughout flax: this essentially involved changing @@ -27,7 +29,7 @@ to keyword arguments. See more details in [#3389](https://github.com/google/flax - - - -- +- NOTE: Remember to bump version number to 0.8.0 0.7.3 ----- diff --git a/flax/training/early_stopping.py b/flax/training/early_stopping.py index fd95fd761f..710a4c9a8e 100644 --- a/flax/training/early_stopping.py +++ b/flax/training/early_stopping.py @@ -30,7 +30,7 @@ class EarlyStopping(struct.PyTreeNode): rng, input_rng = jax.random.split(rng) optimizer, train_metrics = train_epoch( optimizer, train_ds, config.batch_size, epoch, input_rng) - _, early_stop = early_stop.update(train_metrics['loss']) + early_stop = early_stop.update(train_metrics['loss']) if early_stop.should_stop: print('Met early stopping criteria, breaking...') break @@ -43,6 +43,8 @@ class EarlyStopping(struct.PyTreeNode): patience_count: Number of steps since last improving update. should_stop: Whether the training loop should stop to avoid overfitting. + has_improved: Whether the metric has improved greater or + equal to the min_delta in the last `.update` call. """ min_delta: float = 0 @@ -50,28 +52,29 @@ class EarlyStopping(struct.PyTreeNode): best_metric: float = float('inf') patience_count: int = 0 should_stop: bool = False + has_improved: bool = False def reset(self): return self.replace( - best_metric=float('inf'), patience_count=0, should_stop=False + best_metric=float('inf'), patience_count=0, should_stop=False, has_improved=False ) def update(self, metric): """Update the state based on metric. Returns: - A pair (has_improved, early_stop), where `has_improved` is True when there - was an improvement greater than `min_delta` from the previous - `best_metric` and `early_stop` is the updated `EarlyStop` object. + The updated EarlyStopping class. The `.has_improved` attribute is True + when there was an improvement greater than `min_delta` from the previous + `best_metric`. """ if ( math.isinf(self.best_metric) or self.best_metric - metric > self.min_delta ): - return True, self.replace(best_metric=metric, patience_count=0) + return self.replace(best_metric=metric, patience_count=0, has_improved=True) else: should_stop = self.patience_count >= self.patience or self.should_stop - return False, self.replace( - patience_count=self.patience_count + 1, should_stop=should_stop + return self.replace( + patience_count=self.patience_count + 1, should_stop=should_stop, has_improved=False ) diff --git a/tests/early_stopping_test.py b/tests/early_stopping_test.py index ea6ce47ddc..f62ad065dd 100644 --- a/tests/early_stopping_test.py +++ b/tests/early_stopping_test.py @@ -32,8 +32,8 @@ def test_update(self): improve_steps = 0 for step in range(10): metric = 1.0 - did_improve, es = es.update(metric) - if not did_improve: + es = es.update(metric) + if not es.has_improved: improve_steps += 1 if es.should_stop: break @@ -48,7 +48,7 @@ def test_patience(self): patient_es = early_stopping.EarlyStopping(min_delta=0, patience=6) for step in range(10): metric = 1.0 - did_improve, es = es.update(metric) + es = es.update(metric) if es.should_stop: break @@ -56,7 +56,7 @@ def test_patience(self): for patient_step in range(10): metric = 1.0 - did_improve, patient_es = patient_es.update(metric) + patient_es = patient_es.update(metric) if patient_es.should_stop: break @@ -69,7 +69,7 @@ def test_delta(self): metric = 1.0 for step in range(100): metric -= 1e-4 - did_improve, es = es.update(metric) + es = es.update(metric) if es.should_stop: break @@ -78,7 +78,7 @@ def test_delta(self): metric = 1.0 for step in range(100): metric -= 1e-4 - did_improve, delta_es = delta_es.update(metric) + delta_es = delta_es.update(metric) if delta_es.should_stop: break @@ -99,8 +99,8 @@ def test_delta(self): improvement_steps = 0 for step in range(10): metric = metrics[step] - did_improve, delta_patient_es = delta_patient_es.update(metric) - if did_improve: + delta_patient_es = delta_patient_es.update(metric) + if delta_patient_es.has_improved: improvement_steps += 1 if delta_patient_es.should_stop: break