Deprecate early_stop_callback Trainer argument (part 2) (#3845)
* update tests with EarlyStopping default * imports * revert legacy tests * fix test * revert * revert
This commit is contained in:
parent
6723b924f8
commit
cc9781a0ad
|
@ -181,7 +181,7 @@ class EarlyStopping(Callback):
|
|||
current = logs.get(self.monitor)
|
||||
|
||||
# when in dev debugging
|
||||
trainer.dev_debugger.track_early_stopping_history(current)
|
||||
trainer.dev_debugger.track_early_stopping_history(self, current)
|
||||
|
||||
if not isinstance(current, torch.Tensor):
|
||||
current = torch.tensor(current, device=pl_module.device)
|
||||
|
|
|
@ -154,15 +154,14 @@ class InternalDebugger(object):
|
|||
self.pbar_added_metrics.append(metrics)
|
||||
|
||||
@enabled_only
|
||||
def track_early_stopping_history(self, current):
|
||||
es = self.trainer.early_stop_callback
|
||||
def track_early_stopping_history(self, callback, current):
|
||||
debug_dict = {
|
||||
'epoch': self.trainer.current_epoch,
|
||||
'global_step': self.trainer.global_step,
|
||||
'rank': self.trainer.global_rank,
|
||||
'current': current,
|
||||
'best': es.best_score,
|
||||
'patience': es.wait_count
|
||||
'best': callback.best_score,
|
||||
'patience': callback.wait_count
|
||||
}
|
||||
self.early_stopping_history.append(debug_dict)
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import torch
|
|||
|
||||
import tests.base.develop_pipelines as tpipes
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from tests.base import EvalModelTemplate
|
||||
from pytorch_lightning.core import memory
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
|
@ -15,7 +16,7 @@ def test_multi_gpu_early_stop_ddp_spawn(tmpdir):
|
|||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
early_stop_callback=True,
|
||||
callbacks=[EarlyStopping()],
|
||||
max_epochs=50,
|
||||
limit_train_batches=10,
|
||||
limit_val_batches=10,
|
||||
|
|
|
@ -77,7 +77,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
|
|||
expected_count = 4
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
early_stop_callback=True,
|
||||
callbacks=[EarlyStopping()],
|
||||
val_check_interval=1.0,
|
||||
max_epochs=expected_count,
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ from torchtext.data import Batch, Dataset, Example, Field, LabelField
|
|||
import tests.base.develop_pipelines as tpipes
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from pytorch_lightning.core import memory
|
||||
from pytorch_lightning.utilities import device_parser
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -25,7 +26,7 @@ def test_multi_gpu_early_stop_dp(tmpdir):
|
|||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
early_stop_callback=True,
|
||||
callbacks=[EarlyStopping()],
|
||||
max_epochs=50,
|
||||
limit_train_batches=10,
|
||||
limit_val_batches=10,
|
||||
|
|
|
@ -6,6 +6,7 @@ from torch.utils.data import DataLoader
|
|||
import tests.base.develop_pipelines as tpipes
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pytorch_lightning.accelerators import TPUBackend
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.base.datasets import TrialMNIST
|
||||
|
@ -155,7 +156,7 @@ def test_model_tpu_early_stop(tmpdir):
|
|||
"""Test if single TPU core training works"""
|
||||
model = EvalModelTemplate()
|
||||
trainer = Trainer(
|
||||
early_stop_callback=True,
|
||||
callbacks=[EarlyStopping()],
|
||||
default_root_dir=tmpdir,
|
||||
progress_bar_refresh_rate=0,
|
||||
max_epochs=50,
|
||||
|
@ -261,7 +262,7 @@ def test_result_obj_on_tpu(tmpdir):
|
|||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=epochs,
|
||||
early_stop_callback=True,
|
||||
callbacks=[EarlyStopping()],
|
||||
row_log_interval=2,
|
||||
limit_train_batches=batches,
|
||||
weights_summary=None,
|
||||
|
|
Loading…
Reference in New Issue