lightning/tests/callbacks/test_early_stopping.py

179 lines
6.0 KiB
Python
Raw Normal View History

Continue Jeremy's early stopping PR #1504 (#2391) * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model checkpoint change * typo fix * fix tests * update train loop * cannot pass an int as default_save_path * refactor log message * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * fix formatting * remove enable_early_stop attribute * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model checkpoint change * typo fix * fix tests * update train loop * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * fix formatting * remove enable_early_stop attribute * fix test with new epoch indexing * fix progress bar totals * fix off by one error (see #2289) epoch starts at 0 now * added missing imports * fix hpc_save folderpath * fix formatting * fix tests * small fixes from a rebase * fix * tmpdir * tmpdir * tmpdir * wandb * fix merge conflict * add back evaluation after training * test_resume_early_stopping_from_checkpoint TODO * undo the horovod check * update changelog * remove a duplicate test from merge error * try fix dp_resume test * add the logger fix from master * try remove default_root_dir * try mocking numpy * try import numpy in docs test * fix wandb test * pep 8 fix * skip if no amp * dont mock when doctesting * install extra * fix the resume ES test * undo conf.py changes * revert remove comet pickle from test * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update weights_loading.rst * Update weights_loading.rst * Update weights_loading.rst * renamed flag * renamed flag * revert the None check in logger experiment name/version * add the old comments * _experiment * test chckpointing on DDP * skip the ddp test on windows * cloudpickle * renamed flag * renamed flag * parentheses for clarity * apply suggestion max epochs Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jeremy Jordan <jtjordan@ncsu.edu> Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu>
2020-06-29 01:36:46 +00:00
import pickle
import cloudpickle
import pytest
import torch
2020-08-06 14:58:51 +00:00
Continue Jeremy's early stopping PR #1504 (#2391) * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model checkpoint change * typo fix * fix tests * update train loop * cannot pass an int as default_save_path * refactor log message * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * fix formatting * remove enable_early_stop attribute * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model checkpoint change * typo fix * fix tests * update train loop * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * fix formatting * remove enable_early_stop attribute * fix test with new epoch indexing * fix progress bar totals * fix off by one error (see #2289) epoch starts at 0 now * added missing imports * fix hpc_save folderpath * fix formatting * fix tests * small fixes from a rebase * fix * tmpdir * tmpdir * tmpdir * wandb * fix merge conflict * add back evaluation after training * test_resume_early_stopping_from_checkpoint TODO * undo the horovod check * update changelog * remove a duplicate test from merge error * try fix dp_resume test * add the logger fix from master * try remove default_root_dir * try mocking numpy * try import numpy in docs test * fix wandb test * pep 8 fix * skip if no amp * dont mock when doctesting * install extra * fix the resume ES test * undo conf.py changes * revert remove comet pickle from test * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update weights_loading.rst * Update weights_loading.rst * Update weights_loading.rst * renamed flag * renamed flag * revert the None check in logger experiment name/version * add the old comments * _experiment * test chckpointing on DDP * skip the ddp test on windows * cloudpickle * renamed flag * renamed flag * parentheses for clarity * apply suggestion max epochs Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jeremy Jordan <jtjordan@ncsu.edu> Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu>
2020-06-29 01:36:46 +00:00
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate
def test_resume_early_stopping_from_checkpoint(tmpdir):
"""
Prevent regressions to bugs:
https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
"""
class EarlyStoppingTestStore(EarlyStopping):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# cache the state for each epoch
self.saved_states = []
def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.saved_states.append(self.state_dict().copy())
class EarlyStoppingTestRestore(EarlyStopping):
def __init__(self, expected_state):
super().__init__()
self.expected_state = expected_state
def on_train_start(self, trainer, pl_module):
assert self.state_dict() == self.expected_state
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(save_top_k=1)
early_stop_callback = EarlyStoppingTestStore()
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
early_stop_callback=early_stop_callback,
max_epochs=4,
)
trainer.fit(model)
checkpoint_filepath = checkpoint_callback.kth_best_model
# ensure state is persisted properly
checkpoint = torch.load(checkpoint_filepath)
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint['epoch'] - 1]
assert 4 == len(early_stop_callback.saved_states)
assert checkpoint['early_stop_callback_state_dict'] == early_stop_callback_state
# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
new_trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
resume_from_checkpoint=checkpoint_filepath,
early_stop_callback=early_stop_callback,
)
new_trainer.fit(model)
def test_early_stopping_no_extraneous_invocations(tmpdir):
"""Test to ensure that callback methods aren't being invoked outside of the callback handler."""
class EarlyStoppingTestInvocations(EarlyStopping):
def __init__(self, expected_count):
super().__init__()
self.count = 0
self.expected_count = expected_count
def on_validation_end(self, trainer, pl_module):
self.count += 1
def on_train_end(self, trainer, pl_module):
assert self.count == self.expected_count
model = EvalModelTemplate()
expected_count = 4
early_stop_callback = EarlyStoppingTestInvocations(expected_count)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=early_stop_callback,
val_check_interval=1.0,
max_epochs=expected_count,
)
trainer.fit(model)
@pytest.mark.parametrize('loss_values, patience, expected_stop_epoch', [
([6, 5, 5, 5, 5, 5], 3, 4),
([6, 5, 4, 4, 3, 3], 1, 3),
([6, 5, 6, 5, 5, 5], 3, 4),
])
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
"""Test to ensure that early stopping is not triggered before patience is exhausted."""
class ModelOverrideValidationReturn(EvalModelTemplate):
validation_return_values = torch.Tensor(loss_values)
count = 0
def validation_epoch_end(self, outputs):
loss = self.validation_return_values[self.count]
self.count += 1
return {"test_val_loss": loss}
model = ModelOverrideValidationReturn()
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=early_stop_callback,
val_check_interval=1.0,
num_sanity_val_steps=0,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
def test_pickling(tmpdir):
early_stopping = EarlyStopping()
early_stopping_pickled = pickle.dumps(early_stopping)
early_stopping_loaded = pickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)
early_stopping_pickled = cloudpickle.dumps(early_stopping)
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)
def test_early_stopping_no_val_step(tmpdir):
"""Test that early stopping callback falls back to training metrics when no validation defined."""
class CurrentModel(EvalModelTemplate):
def training_step(self, *args, **kwargs):
output = super().training_step(*args, **kwargs)
output.update({'my_train_metric': output['loss']}) # could be anything else
return output
model = CurrentModel()
model.validation_step = None
model.val_dataloader = None
stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=stopping,
overfit_batches=0.20,
max_epochs=2,
)
result = trainer.fit(model)
assert result == 1, 'training failed to complete'
assert trainer.current_epoch < trainer.max_epochs
def test_early_stopping_functionality(tmpdir):
class CurrentModel(EvalModelTemplate):
def validation_epoch_end(self, outputs):
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
return {'val_loss': torch.tensor(val_loss)}
model = CurrentModel()
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=True,
overfit_batches=0.20,
max_epochs=20,
)
trainer.fit(model)
assert trainer.current_epoch == 5, 'early_stopping failed'