Fix for multiple callbacks (#6197)
* Fix for multiple callbacks
* Add CHANGELOG.md
* Remove old params
* Skip tests on windows using ddp
* Change name of the variable to not clash with should stop, which is separate
* Apply suggestions from code review
* Fix params

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
3ed8ef8af9
commit
dd2f5a0212
|
@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
|
||||
* from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
|
||||
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
|
||||
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
|
||||
|
||||
|
||||
- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
|
||||
|
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
|
||||
|
||||
|
||||
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
|
||||
|
||||
|
||||
## [1.2.1] - 2021-02-23
|
||||
|
||||
### Fixed
|
||||
|
|
|
@@ -158,15 +158,12 @@ class EarlyStopping(Callback):
|
|||
if self.monitor_op(current - self.min_delta, self.best_score):
|
||||
self.best_score = current
|
||||
self.wait_count = 0
|
||||
should_stop = False
|
||||
else:
|
||||
self.wait_count += 1
|
||||
should_stop = self.wait_count >= self.patience
|
||||
|
||||
if bool(should_stop):
|
||||
if self.wait_count >= self.patience:
|
||||
self.stopped_epoch = trainer.current_epoch
|
||||
trainer.should_stop = True
|
||||
|
||||
# stop every ddp process if any world process decides to stop
|
||||
should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop)
|
||||
trainer.should_stop = should_stop
|
||||
trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)
|
||||
|
|
|
@@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from unittest import mock
|
||||
|
||||
import cloudpickle
|
||||
|
@@ -344,3 +345,57 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
|
|||
def test_early_stopping_mode_options():
    """Constructing ``EarlyStopping`` with an unsupported ``mode`` must raise."""
    expected_match = "`mode` can be .* got unknown_option"
    with pytest.raises(MisconfigurationException, match=expected_match):
        EarlyStopping(mode="unknown_option")
|
||||
|
||||
|
||||
class EarlyStoppingModel(BoringModel):
    """BoringModel variant that logs a scripted validation metric curve.

    The metric logged under ``'abc'`` follows a fixed dip-then-rise sequence so
    an early-stopping callback monitoring it fires at a predictable epoch; the
    metric logged under ``'cba'`` is constant, so a callback monitoring it
    should never trigger on its own.
    """

    # Fixed per-epoch values logged as the 'abc' metric (index = current epoch).
    _SCRIPTED_LOSSES = (8, 4, 2, 3, 4, 5, 8, 10)

    def __init__(self, expected_end_epoch):
        super().__init__()
        # Epoch at which training is expected to have been halted.
        self.expected_end_epoch = expected_end_epoch

    def validation_epoch_end(self, outputs):
        scripted_loss = self._SCRIPTED_LOSSES[self.current_epoch]
        self.log('abc', torch.tensor(scripted_loss))
        self.log('cba', torch.tensor(0))

    def on_train_end(self) -> None:
        # Confirm early stopping ended training exactly where the test expects.
        assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "callbacks, expected_stop_epoch, accelerator, num_processes",
    [
        # Single-process runs, both callback orderings.
        ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
        ([EarlyStopping(monitor='cba', patience=3), EarlyStopping(monitor='abc')], 3, None, 1),
        # ddp_cpu runs, both callback orderings; skipped on Windows.
        pytest.param(
            [EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)],
            3,
            'ddp_cpu',
            2,
            marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows"),
        ),
        pytest.param(
            [EarlyStopping(monitor='cba', patience=3), EarlyStopping(monitor='abc')],
            3,
            'ddp_cpu',
            2,
            marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows"),
        ),
    ],
)
def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir):
    """
    Ensure when using multiple early stopping callbacks we stop if any signals we should stop.
    """
    # EarlyStoppingModel asserts (in on_train_end) that training halted at this epoch.
    model = EarlyStoppingModel(expected_stop_epoch)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        overfit_batches=0.20,
        max_epochs=20,
        accelerator=accelerator,
        num_processes=num_processes,
    )
    trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue