From cc9781a0adce9f81eb2e0c3f07551d3dad14bfee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 4 Oct 2020 23:36:47 +0200 Subject: [PATCH] Deprecate early_stop_callback Trainer argument (part 2) (#3845) * update tests with EarlyStopping default * imports * revert legacy tests * fix test * revert * revert --- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/utilities/debugging.py | 7 +++---- tests/backends/test_ddp_spawn.py | 3 ++- tests/callbacks/test_early_stopping.py | 2 +- tests/models/test_gpu.py | 3 ++- tests/models/test_tpu.py | 5 +++-- 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 2330e69ecd..3177c9300e 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -181,7 +181,7 @@ class EarlyStopping(Callback): current = logs.get(self.monitor) # when in dev debugging - trainer.dev_debugger.track_early_stopping_history(current) + trainer.dev_debugger.track_early_stopping_history(self, current) if not isinstance(current, torch.Tensor): current = torch.tensor(current, device=pl_module.device) diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 79a35cba47..242f3105d7 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -154,15 +154,14 @@ class InternalDebugger(object): self.pbar_added_metrics.append(metrics) @enabled_only - def track_early_stopping_history(self, current): - es = self.trainer.early_stop_callback + def track_early_stopping_history(self, callback, current): debug_dict = { 'epoch': self.trainer.current_epoch, 'global_step': self.trainer.global_step, 'rank': self.trainer.global_rank, 'current': current, - 'best': es.best_score, - 'patience': es.wait_count + 'best': callback.best_score, + 'patience': callback.wait_count } self.early_stopping_history.append(debug_dict) diff --git a/tests/backends/test_ddp_spawn.py b/tests/backends/test_ddp_spawn.py index 0c5db6b1a0..a1573b69ed 100644 --- a/tests/backends/test_ddp_spawn.py +++ b/tests/backends/test_ddp_spawn.py @@ -3,6 +3,7 @@ import torch import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils +from pytorch_lightning.callbacks import EarlyStopping from tests.base import EvalModelTemplate from pytorch_lightning.core import memory from pytorch_lightning.trainer import Trainer @@ -15,7 +16,7 @@ def test_multi_gpu_early_stop_ddp_spawn(tmpdir): trainer_options = dict( default_root_dir=tmpdir, - early_stop_callback=True, + callbacks=[EarlyStopping()], max_epochs=50, limit_train_batches=10, limit_val_batches=10, diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 98ff939ae6..8a1daaf695 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -77,7 +77,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir): expected_count = 4 trainer = Trainer( default_root_dir=tmpdir, - early_stop_callback=True, + callbacks=[EarlyStopping()], val_check_interval=1.0, max_epochs=expected_count, ) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 8c2be4cabc..4b3d95c254 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -8,6 +8,7 @@ from torchtext.data import Batch, Dataset, Example, Field, LabelField import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.core import memory from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -25,7 +26,7 @@ def test_multi_gpu_early_stop_dp(tmpdir): trainer_options = dict( default_root_dir=tmpdir, - early_stop_callback=True, + callbacks=[EarlyStopping()], max_epochs=50, limit_train_batches=10, limit_val_batches=10, diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index ef82cf4e46..cddc3db78a 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -6,6 +6,7 @@ from torch.utils.data import DataLoader import tests.base.develop_pipelines as tpipes from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.accelerators import TPUBackend +from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate from tests.base.datasets import TrialMNIST @@ -155,7 +156,7 @@ def test_model_tpu_early_stop(tmpdir): """Test if single TPU core training works""" model = EvalModelTemplate() trainer = Trainer( - early_stop_callback=True, + callbacks=[EarlyStopping()], default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=50, @@ -261,7 +262,7 @@ def test_result_obj_on_tpu(tmpdir): trainer_options = dict( default_root_dir=tmpdir, max_epochs=epochs, - early_stop_callback=True, + callbacks=[EarlyStopping()], row_log_interval=2, limit_train_batches=batches, weights_summary=None,