Fix for multiple callbacks (#6197)

* Fix for multiple callbacks

* Add CHANGELOG.md

* Remove old params

* Skip tests on windows using ddp

* Change name of the variable to not clash with should stop, which is separate

* Apply suggestions from code review

* Fix params

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Sean Naren 2021-02-25 15:44:55 +00:00 committed by GitHub
parent 3ed8ef8af9
commit dd2f5a0212
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 6 deletions

View File

@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
* from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
## [1.2.1] - 2021-02-23
### Fixed

View File

@ -158,15 +158,12 @@ class EarlyStopping(Callback):
if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
should_stop = False
else:
self.wait_count += 1
should_stop = self.wait_count >= self.patience
if bool(should_stop):
if self.wait_count >= self.patience:
self.stopped_epoch = trainer.current_epoch
trainer.should_stop = True
# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop)
trainer.should_stop = should_stop
trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)

View File

@ -13,6 +13,7 @@
# limitations under the License.
import os
import pickle
import sys
from unittest import mock
import cloudpickle
@ -344,3 +345,57 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
def test_early_stopping_mode_options():
with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
EarlyStopping(mode="unknown_option")
class EarlyStoppingModel(BoringModel):
def __init__(self, expected_end_epoch):
super().__init__()
self.expected_end_epoch = expected_end_epoch
def validation_epoch_end(self, outputs):
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
self.log('abc', torch.tensor(val_loss))
self.log('cba', torch.tensor(0))
def on_train_end(self) -> None:
assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'
@pytest.mark.parametrize(
"callbacks, expected_stop_epoch, accelerator, num_processes",
[
([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
([EarlyStopping(monitor='cba', patience=3),
EarlyStopping(monitor='abc')], 3, None, 1),
pytest.param([EarlyStopping(monitor='abc'),
EarlyStopping(monitor='cba', patience=3)],
3,
'ddp_cpu',
2,
marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")),
pytest.param([EarlyStopping(monitor='cba', patience=3),
EarlyStopping(monitor='abc')],
3,
'ddp_cpu',
2,
marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")),
],
)
def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir):
"""
Ensure when using multiple early stopping callbacks we stop if any signals we should stop.
"""
model = EarlyStoppingModel(expected_stop_epoch)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
overfit_batches=0.20,
max_epochs=20,
accelerator=accelerator,
num_processes=num_processes
)
trainer.fit(model)