diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index b1034ef7d7..d109655243 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -1,38 +1,8 @@ -from pathlib import Path - -import pytest -import torch - -import tests.base.develop_utils as tutils from pytorch_lightning import Callback from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger from tests.base import EvalModelTemplate -def test_early_stopping_functionality(tmpdir): - - class CurrentModel(EvalModelTemplate): - def validation_epoch_end(self, outputs): - losses = [8, 4, 2, 3, 4, 5, 8, 10] - val_loss = losses[self.current_epoch] - return {'val_loss': torch.tensor(val_loss)} - - model = CurrentModel() - - trainer = Trainer( - default_root_dir=tmpdir, - early_stop_callback=True, - overfit_batches=0.20, - max_epochs=20, - ) - result = trainer.fit(model) - print(trainer.current_epoch) - - assert trainer.current_epoch == 5, 'early_stopping failed' - - def test_trainer_callback_system(tmpdir): """Test the callback system.""" @@ -262,86 +232,3 @@ def test_trainer_callback_system(tmpdir): assert not test_callback.on_validation_end_called assert not test_callback.on_validation_batch_end_called assert not test_callback.on_validation_batch_start_called - - -def test_early_stopping_no_val_step(tmpdir): - """Test that early stopping callback falls back to training metrics when no validation defined.""" - - class CurrentModel(EvalModelTemplate): - def training_step(self, *args, **kwargs): - output = super().training_step(*args, **kwargs) - output.update({'my_train_metric': output['loss']}) # could be anything else - return output - - model = CurrentModel() - model.validation_step = None - model.val_dataloader = None - - stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1) - trainer = Trainer( - default_root_dir=tmpdir, - early_stop_callback=stopping, - overfit_batches=0.20, - max_epochs=2, - ) - result = trainer.fit(model) - - assert result == 1, 'training failed to complete' - assert trainer.current_epoch < trainer.max_epochs - - -def test_pickling(tmpdir): - import pickle - early_stopping = EarlyStopping() - ckpt = ModelCheckpoint(tmpdir) - - early_stopping_pickled = pickle.dumps(early_stopping) - ckpt_pickled = pickle.dumps(ckpt) - - early_stopping_loaded = pickle.loads(early_stopping_pickled) - ckpt_loaded = pickle.loads(ckpt_pickled) - - assert vars(early_stopping) == vars(early_stopping_loaded) - assert vars(ckpt) == vars(ckpt_loaded) - - -@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) -def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): - """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ - tutils.reset_seed() - model = EvalModelTemplate() - - checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k) - - trainer = Trainer( - default_root_dir=tmpdir, - checkpoint_callback=checkpoint, - overfit_batches=0.20, - max_epochs=2, - ) - trainer.fit(model) - - # These should be different if the dirpath has be overridden - assert trainer.ckpt_path != trainer.default_root_dir - - -@pytest.mark.parametrize( - 'logger_version,expected', - [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')], -) -def test_model_checkpoint_path(tmpdir, logger_version, expected): - """Test that "version_" prefix is only added when logger's version is an integer""" - tutils.reset_seed() - model = EvalModelTemplate() - logger = TensorBoardLogger(str(tmpdir), version=logger_version) - - trainer = Trainer( - default_root_dir=tmpdir, - overfit_batches=0.2, - max_epochs=2, - logger=logger, - ) - trainer.fit(model) - - ckpt_version = Path(trainer.ckpt_path).parent.name - assert ckpt_version == expected diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 2ba434af26..17ca3bb221 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -130,3 +130,49 @@ def test_pickling(tmpdir): early_stopping_pickled = cloudpickle.dumps(early_stopping) early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) + + +def test_early_stopping_no_val_step(tmpdir): + """Test that early stopping callback falls back to training metrics when no validation defined.""" + + class CurrentModel(EvalModelTemplate): + def training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + output.update({'my_train_metric': output['loss']}) # could be anything else + return output + + model = CurrentModel() + model.validation_step = None + model.val_dataloader = None + + stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1) + trainer = Trainer( + default_root_dir=tmpdir, + early_stop_callback=stopping, + overfit_batches=0.20, + max_epochs=2, + ) + result = trainer.fit(model) + + assert result == 1, 'training failed to complete' + assert trainer.current_epoch < trainer.max_epochs + + +def test_early_stopping_functionality(tmpdir): + + class CurrentModel(EvalModelTemplate): + def validation_epoch_end(self, outputs): + losses = [8, 4, 2, 3, 4, 5, 8, 10] + val_loss = losses[self.current_epoch] + return {'val_loss': torch.tensor(val_loss)} + + model = CurrentModel() + + trainer = Trainer( + default_root_dir=tmpdir, + early_stop_callback=True, + overfit_batches=0.20, + max_epochs=20, + ) + trainer.fit(model) + assert trainer.current_epoch == 5, 'early_stopping failed' diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 7257dc3874..bb575494c3 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -15,9 +15,7 @@ from tests.base import EvalModelTemplate @pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): - """ - Test that None in checkpoint callback is valid and that chkp_path is set correctly - """ + """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ tutils.reset_seed() model = EvalModelTemplate() @@ -26,8 +24,8 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint, - overfit_pct=0.20, - max_epochs=(save_top_k + 2), + overfit_batches=0.20, + max_epochs=2, ) trainer.fit(model) @@ -47,8 +45,8 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected): trainer = Trainer( default_root_dir=tmpdir, - overfit_pct=0.2, - max_epochs=5, + overfit_batches=0.2, + max_epochs=2, logger=logger, ) trainer.fit(model)