From d9bc95f83e163f1ef0e64012ad086d4448410817 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 3 Oct 2020 12:33:29 -0400
Subject: [PATCH] ref: bug fix with logging val epoch end + monitor (#3812)

* ref: fix metric err
* ref: merge
* ref: decoupled ddp2
* ref: clean up ddp before final fix
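The contract exercised throughout this patch: a value logged with `self.log`
inside `validation_epoch_end` must land in `trainer.callback_metrics`, where
`ModelCheckpoint` and `EarlyStopping` look up their `monitor` key. A minimal
end-to-end sketch of that usage; the module name `TinyModule`, the metric name
`avg_val_loss`, and the hyperparameters are illustrative, not part of this patch:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    class TinyModule(LightningModule):
        # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, = batch
            loss = torch.nn.functional.mse_loss(self(x), torch.ones_like(self(x)))
            return {'loss': loss}

        def validation_step(self, batch, batch_idx):
            x, = batch
            loss = torch.nn.functional.mse_loss(self(x), torch.ones_like(self(x)))
            return {'loss': loss}

        def validation_epoch_end(self, outputs):
            avg_val_loss = torch.stack([o['loss'] for o in outputs]).mean()
            # logged at epoch end; this fix makes it reach callback_metrics
            self.log('avg_val_loss', avg_val_loss)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(64, 32)))

        def val_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(64, 32)))

    # the monitored key matches what validation_epoch_end logged
    trainer = Trainer(max_epochs=2, checkpoint_callback=ModelCheckpoint(monitor='avg_val_loss'))
    trainer.fit(TinyModule())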
---
 pytorch_lightning/callbacks/early_stopping.py |  4 +-
 .../trainer/connectors/logger_connector.py    |  5 +-
 .../trainer/connectors/optimizer_connector.py | 10 ++-
 pytorch_lightning/trainer/evaluation_loop.py  |  4 +
 pytorch_lightning/trainer/optimizers.py       | 18 ++--
 pytorch_lightning/trainer/training_loop.py    | 10 +--
 tests/base/__init__.py                        |  1 +
 tests/base/model_valid_epoch_ends.py          |  2 +-
 tests/base/simple_model.py                    | 85 +++++++++++++++++++
 tests/callbacks/test_early_stopping.py        |  6 +-
 tests/callbacks/test_model_checkpoint.py      | 30 +++----
 tests/core/test_datamodules.py                |  2 +
 tests/loggers/test_all.py                     | 22 +++--
 tests/models/test_amp.py                      |  2 +
 tests/models/test_restore.py                  |  2 +-
 tests/trainer/data_flow/__init__.py           |  0
 .../test_eval_loop_flow_1_0.py                |  0
 .../test_train_loop_flow_dict_1_0.py          |  0
 .../test_train_loop_flow_scalar_1_0.py        |  0
 .../__init__.py                               |  0
 .../test_eval_loop_dict_return.py             |  8 +-
 .../test_trainer_steps_dict_return.py         |  0
 .../test_trainer_steps_result_return.py       |  0
 .../test_trainer_steps_scalar_return.py       |  0
 .../test_validation_steps_result_return.py    |  0
 tests/trainer/logging/__init__.py             |  0
 .../test_eval_loop_logging_1_0.py             | 15 ++++
 .../test_train_loop_logging_1_0.py            |  0
 tests/trainer/test_optimizers.py              | 35 ++++++--
 tests/trainer/test_trainer.py                 |  7 +-
 30 files changed, 212 insertions(+), 56 deletions(-)
 create mode 100644 tests/base/simple_model.py
 create mode 100644 tests/trainer/data_flow/__init__.py
 rename tests/trainer/{ => data_flow}/test_eval_loop_flow_1_0.py (100%)
 rename tests/trainer/{ => data_flow}/test_train_loop_flow_dict_1_0.py (100%)
 rename tests/trainer/{ => data_flow}/test_train_loop_flow_scalar_1_0.py (100%)
 create mode 100644 tests/trainer/legacy_deprecate_flow_log_tests/__init__.py
 rename tests/trainer/{ => legacy_deprecate_flow_log_tests}/test_eval_loop_dict_return.py (97%)
 rename tests/trainer/{ => legacy_deprecate_flow_log_tests}/test_trainer_steps_dict_return.py (100%)
 rename tests/trainer/{ => legacy_deprecate_flow_log_tests}/test_trainer_steps_result_return.py (100%)
 rename tests/trainer/{ => legacy_deprecate_flow_log_tests}/test_trainer_steps_scalar_return.py (100%)
 rename tests/trainer/{ => legacy_deprecate_flow_log_tests}/test_validation_steps_result_return.py (100%)
 create mode 100644 tests/trainer/logging/__init__.py
 rename tests/trainer/{ => logging}/test_eval_loop_logging_1_0.py (90%)
 rename tests/trainer/{ => logging}/test_train_loop_logging_1_0.py (100%)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 1a615633cd..eeb344ae64 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -106,9 +106,9 @@ class EarlyStopping(Callback):
 
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
+
         error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
-                     f' which is not available. Either add `{self.monitor}` to the return of'
-                     ' `validation_epoch_end` or modify your `EarlyStopping` callback to use any of the'
+                     f' which is not available. Pass in or modify your `EarlyStopping` callback to use any of the'
                      f' following: `{"`, `".join(list(logs.keys()))}`')
 
         if monitor_val is None:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py
index d758ee7808..318fad9c47 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector.py
@@ -171,7 +171,10 @@ class LoggerConnector:
         return result
 
     def _track_callback_metrics(self, eval_results, using_eval_result):
-        if len(eval_results) > 0 and eval_results[0] is None:
+        if (
+            len(eval_results) > 0 and
+            (eval_results[0] is None or not isinstance(eval_results[0], Result))
+        ):
             return
 
         if using_eval_result:
diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py
index 092245dbe4..32a554a5a6 100644
--- a/pytorch_lightning/trainer/connectors/optimizer_connector.py
+++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -43,7 +43,13 @@ class OptimizerConnector:
             if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
                 # If instance of ReduceLROnPlateau, we need to pass validation loss
                 if lr_scheduler['reduce_on_plateau']:
-                    monitor_key = lr_scheduler['monitor']
+                    try:
+                        monitor_key = lr_scheduler['monitor']
+                    except KeyError:
+                        m = "ReduceLROnPlateau requires returning a dict from configure_optimizers with the keyword " \
+                            "monitor=. For example: " \
+                            "return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'your_loss'}"
+                        raise MisconfigurationException(m)
 
                     if monitor_metrics is not None:
                         monitor_val = monitor_metrics.get(monitor_key)
@@ -54,7 +60,7 @@ class OptimizerConnector:
                         avail_metrics = ','.join(list(self.trainer.logger_connector.callback_metrics.keys()))
                         raise MisconfigurationException(
                             f'ReduceLROnPlateau conditioned on metric {monitor_key}'
-                            f' which is not available. Available metrics are: {avail_metrics}.'
+                            f' which is not available. Available metrics are: [{avail_metrics}].'
                             ' Condition can be set using `monitor` key in lr scheduler dict'
                         )
                         # update LR
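For reference, the return shape the new MisconfigurationException asks for,
a single dict carrying the scheduler together with the metric to monitor,
mirrors the new test added to tests/trainer/test_optimizers.py. A sketch
(the subclass name is illustrative; 'early_stop_on' must be a metric that is
available to callbacks):

    import torch
    from tests.base import EvalModelTemplate

    class PlateauModel(EvalModelTemplate):
        # hypothetical subclass, for illustration only
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
            # 'monitor' tells the trainer which logged metric drives the plateau scheduler
            return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'monitor': 'early_stop_on'}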
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index b1a2b21b03..38eb5b9465 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -228,6 +228,10 @@ class EvaluationLoop(object):
         if using_eval_result and not user_reduced:
             eval_results = self.__auto_reduce_result_objs(outputs)
 
+        result = model._results
+        if len(result) > 0 and eval_results is None:
+            eval_results = result.get_epoch_log_metrics()
+
         if not isinstance(eval_results, list):
             eval_results = [eval_results]
 
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 6aaa0dc663..1b3d4b86ea 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -50,9 +50,10 @@ class TrainerOptimizersMixin(ABC):
         # single dictionary
         elif isinstance(optim_conf, dict):
             optimizer = optim_conf["optimizer"]
+            monitor = optim_conf.get('monitor', None)
             lr_scheduler = optim_conf.get("lr_scheduler", [])
             if lr_scheduler:
-                lr_schedulers = self.configure_schedulers([lr_scheduler])
+                lr_schedulers = self.configure_schedulers([lr_scheduler], monitor)
             else:
                 lr_schedulers = []
             return [optimizer], lr_schedulers, []
@@ -94,13 +95,18 @@ class TrainerOptimizersMixin(ABC):
                 ' a list of `torch.optim.lr_scheduler`'
                 ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
 
-    def configure_schedulers(self, schedulers: list):
+    def configure_schedulers(self, schedulers: list, monitor: str = None):
         # Convert each scheduler into dict structure with relevant information
         lr_schedulers = []
-        default_config = {'interval': 'epoch',  # default every epoch
-                          'frequency': 1,  # default every epoch/batch
-                          'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
-                          'monitor': 'val_loss'}  # default value to monitor for ReduceLROnPlateau
+        default_config = {
+            'interval': 'epoch',  # default every epoch
+            'frequency': 1,  # default every epoch/batch
+            'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
+        }
+
+        if monitor is not None:
+            default_config['monitor'] = monitor
+
         for scheduler in schedulers:
             if isinstance(scheduler, dict):
                 if 'scheduler' not in scheduler:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 291bba0ad9..5b16ea5b1f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -531,6 +531,11 @@ class TrainLoop:
             # TODO: add outputs to batches
             self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
 
+            # -----------------------------------------
+            # SAVE METRICS TO LOGGERS
+            # -----------------------------------------
+            self.trainer.logger_connector.log_train_step_metrics(batch_output)
+
             # -----------------------------------------
             # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
             # -----------------------------------------
@@ -538,11 +543,6 @@ class TrainLoop:
             if should_check_val:
                 self.trainer.run_evaluation(test_mode=False)
 
-            # -----------------------------------------
-            # SAVE METRICS TO LOGGERS
-            # -----------------------------------------
-            self.trainer.logger_connector.log_train_step_metrics(batch_output)
-
             # -----------------------------------------
             # SAVE LOGGERS (ie: Tensorboard, etc...)
             # -----------------------------------------
diff --git a/tests/base/__init__.py b/tests/base/__init__.py
index fe2b327f41..6de94720eb 100644
--- a/tests/base/__init__.py
+++ b/tests/base/__init__.py
@@ -2,3 +2,4 @@ from tests.base.datasets import TrialMNIST
 
 from tests.base.model_template import EvalModelTemplate, GenericEvalModelTemplate
+from tests.base.simple_model import SimpleModule
diff --git a/tests/base/model_valid_epoch_ends.py b/tests/base/model_valid_epoch_ends.py
index c2a079113a..a2b8e62480 100644
--- a/tests/base/model_valid_epoch_ends.py
+++ b/tests/base/model_valid_epoch_ends.py
@@ -51,7 +51,7 @@ class ValidationEpochEndVariations(ABC):
             val_loss_mean = val_loss_mean.item()
             val_acc_mean = val_acc_mean.item()
 
-        metrics_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
+        metrics_dict = {'early_stop_on': val_loss_mean, 'val_acc': val_acc_mean}
         results = {'progress_bar': metrics_dict, 'log': metrics_dict}
         return results
diff --git a/tests/base/simple_model.py b/tests/base/simple_model.py
new file mode 100644
index 0000000000..810da7e731
--- /dev/null
+++ b/tests/base/simple_model.py
@@ -0,0 +1,85 @@
+import torch
+from pytorch_lightning import LightningModule
+from torch.utils.data import Dataset
+from typing import Optional
+
+
+class RandomDataset(Dataset):
+    def __init__(self, size, length):
+        self.len = length
+        self.data = torch.randn(length, size)
+
+    def __getitem__(self, index):
+        return self.data[index]
+
+    def __len__(self):
+        return self.len
+
+
+class SimpleModule(LightningModule):
+    def __init__(self, epoch_min_loss_override: Optional[int] = None):
+        """LightningModule for testing purposes.
+
+        Args:
+            epoch_min_loss_override (int, optional): Zero-based index of an epoch whose
+                validation loss will be forced to be the minimum, for testing purposes.
+                If None, this is ignored. Defaults to None.
+        """
+        super().__init__()
+        self.layer = torch.nn.Linear(32, 2)
+        self.epoch_min_loss_override = epoch_min_loss_override
+
+    def forward(self, x):
+        return self.layer(x)
+
+    def loss(self, batch, prediction):
+        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
+        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
+
+    def training_step(self, batch, batch_idx):
+        output = self.forward(batch)
+        loss = self.loss(batch, output)
+        return {"output": output, "loss": loss, "checkpoint_on": loss}
+
+    def validation_step(self, batch, batch_idx):
+        output = self.forward(batch)
+        loss = self.loss(batch, output)
+        return {"output": output, "loss": loss, "checkpoint_on": loss}
+
+    def test_step(self, batch, batch_idx):
+        output = self.forward(batch)
+        loss = self.loss(batch, output)
+        return {"output": output, "loss": loss}
+
+    def training_epoch_end(self, outputs) -> None:
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+        self.log("avg_loss", avg_loss)
+
+    def validation_epoch_end(self, outputs) -> None:
+        avg_val_loss = torch.stack(
+            [torch.randn(1, requires_grad=True) for _ in outputs]
+        ).mean()
+        # For testing purposes allow a nominated epoch to have a low loss
+        if self.current_epoch == self.epoch_min_loss_override:
+            avg_val_loss -= 1e10
+
+        self.log("avg_val_loss", avg_val_loss)
+        self.log("checkpoint_on", avg_val_loss)
+
+    def test_epoch_end(self, outputs) -> None:
+        avg_loss = torch.stack(
+            [torch.randn(1, requires_grad=True) for _ in outputs]
+        ).mean()
+        self.log("test_loss", avg_loss)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+        return [optimizer], [lr_scheduler]
+
+    def train_dataloader(self):
+        return torch.utils.data.DataLoader(RandomDataset(32, 64))
+
+    def val_dataloader(self):
+        return torch.utils.data.DataLoader(RandomDataset(32, 64))
+
+    def test_dataloader(self):
+        return torch.utils.data.DataLoader(RandomDataset(32, 64))
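A sketch of how the new SimpleModule might be driven, mirroring the
test_monitor_val_epoch_end test added later in this patch; the
epoch_min_loss_override value and the epoch count are illustrative:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from tests.base import SimpleModule

    # epoch 1 (zero-based) is forced to have the minimum validation loss
    model = SimpleModule(epoch_min_loss_override=1)
    checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_val_loss")
    trainer = Trainer(max_epochs=3, logger=False, checkpoint_callback=checkpoint_callback)
    trainer.fit(model)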
+ """ + super().__init__() + self.layer = torch.nn.Linear(32, 2) + self.epoch_min_loss_override = epoch_min_loss_override + + def forward(self, x): + return self.layer(x) + + def loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def training_step(self, batch, batch_idx): + output = self.forward(batch) + loss = self.loss(batch, output) + return {"output": output, "loss": loss, "checkpoint_on": loss} + + def validation_step(self, batch, batch_idx): + output = self.forward(batch) + loss = self.loss(batch, output) + return {"output": output, "loss": loss, "checkpoint_on": loss} + + def test_step(self, batch, batch_idx): + output = self.forward(batch) + loss = self.loss(batch, output) + return {"output": output, "loss": loss} + + def training_epoch_end(self, outputs) -> None: + avg_loss = torch.stack([x["loss"] for x in outputs]).mean() + self.log("avg_loss", avg_loss) + + def validation_epoch_end(self, outputs) -> None: + avg_val_loss = torch.stack( + [torch.randn(1, requires_grad=True) for _ in outputs] + ).mean() + # For testing purposes allow a nominated epoch to have a low loss + if self.current_epoch == self.epoch_min_loss_override: + avg_val_loss -= 1e10 + + self.log("avg_val_loss", avg_val_loss) + self.log("checkpoint_on", avg_val_loss) + + def test_epoch_end(self, outputs) -> None: + avg_loss = torch.stack( + [torch.randn(1, requires_grad=True) for _ in outputs] + ).mean() + self.log("test_loss", avg_loss) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index cbc3e28d63..98ff939ae6 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -37,7 +37,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): """ model = EvalModelTemplate() - checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1) + checkpoint_callback = ModelCheckpoint(monitor="early_stop_on", save_top_k=1) early_stop_callback = EarlyStoppingTestRestore() trainer = Trainer( default_root_dir=tmpdir, @@ -159,13 +159,13 @@ def test_early_stopping_functionality(tmpdir): def validation_epoch_end(self, outputs): losses = [8, 4, 2, 3, 4, 5, 8, 10] val_loss = losses[self.current_epoch] - return {'val_loss': torch.tensor(val_loss)} + self.log('abc', torch.tensor(val_loss)) model = CurrentModel() trainer = Trainer( default_root_dir=tmpdir, - early_stop_callback=True, + early_stop_callback=EarlyStopping(monitor='abc'), overfit_batches=0.20, max_epochs=20, ) diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index fd39490318..b3b8204166 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -25,7 +25,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): tutils.reset_seed() model = EvalModelTemplate() - checkpoint = ModelCheckpoint(monitor='val_loss', filepath=None, save_top_k=save_top_k) 
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index fd39490318..b3b8204166 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -25,7 +25,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
     tutils.reset_seed()
     model = EvalModelTemplate()
 
-    checkpoint = ModelCheckpoint(monitor='val_loss', filepath=None, save_top_k=save_top_k)
+    checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=None, save_top_k=save_top_k)
 
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -45,7 +45,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
     tutils.reset_seed()
     model = EvalModelTemplate()
 
-    checkpoint = ModelCheckpoint(filepath=tmpdir, monitor='val_loss', save_top_k=save_top_k)
+    checkpoint = ModelCheckpoint(filepath=tmpdir, monitor='early_stop_on', save_top_k=save_top_k)
 
     trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2)
     trainer.fit(model)
@@ -124,7 +124,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
     """Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
     model = EvalModelTemplate()
     num_epochs = 4
-    model_checkpoint = ModelCheckpointTestInvocations(monitor='val_loss', expected_count=num_epochs, save_top_k=-1)
+    model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1)
     trainer = Trainer(
         distributed_backend="ddp_cpu",
         num_processes=2,
@@ -156,23 +156,23 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
     assert ckpt_name == 'test@epoch=3,acc=0.03000'
     ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
     # no filepath set
-    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=None).format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=None).format_checkpoint_name(3, {})
     assert ckpt_name == 'epoch=3.ckpt'
-    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath='').format_checkpoint_name(5, {})
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, {})
     assert ckpt_name == 'epoch=5.ckpt'
     # CWD
-    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath='.').format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {})
     assert Path(ckpt_name) == Path('.') / 'epoch=3.ckpt'
     # dir does not exist so it is used as filename
     filepath = tmpdir / 'dir'
-    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
     assert ckpt_name == tmpdir / 'test-dir.ckpt'
     # now, dir exists
     os.mkdir(filepath)
-    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
     assert ckpt_name == filepath / 'test-epoch=3.ckpt'
     # with ver
-    ckpt_name = ModelCheckpoint(monitor='val_loss',
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on',
                                 filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
     assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
@@ -182,7 +182,7 @@ def test_model_checkpoint_save_last(tmpdir):
     model = EvalModelTemplate()
     epochs = 3
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
-    model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=-1, save_last=True)
+    model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
@@ -304,7 +304,7 @@ def test_model_checkpoint_topk_all(tmpdir):
     seed_everything(1000)
     epochs = 2
     model = EvalModelTemplate()
monitor="val_loss", save_top_k=-1) + checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor="early_stop_on", save_top_k=-1) trainer = Trainer( default_root_dir=tmpdir, early_stop_callback=False, @@ -330,7 +330,7 @@ def test_ckpt_metric_names(tmpdir): progress_bar_refresh_rate=0, limit_train_batches=0.01, limit_val_batches=0.01, - checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir + "/{val_loss:.2f}"), + checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir + "/{val_loss:.2f}"), ) trainer.fit(model) @@ -390,7 +390,7 @@ def test_ckpt_metric_names_results(tmpdir): progress_bar_refresh_rate=0, limit_train_batches=0.01, limit_val_batches=0.01, - checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir + "/{val_loss:.2f}"), + checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir + "/{val_loss:.2f}"), ) trainer.fit(model) @@ -413,7 +413,7 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v model.validation_step = None trainer = Trainer( default_root_dir=tmpdir, - checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=0, save_last=save_last), + checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=0, save_last=save_last), max_epochs=max_epochs, ) trainer.fit(model) @@ -426,7 +426,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): model = EvalModelTemplate() num_epochs = 3 model_checkpoint = ModelCheckpoint( - monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True + monitor='early_stop_on', filepath=tmpdir, save_top_k=num_epochs, save_last=True ) trainer = Trainer( default_root_dir=tmpdir, diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 5325ca828e..4dc9412eb5 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -11,6 +11,7 @@ from tests.base.datamodules import TrialMNISTDataModule from tests.base.develop_utils import reset_seed from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.accelerators.gpu_backend import GPUBackend +from pytorch_lightning.callbacks import ModelCheckpoint def test_can_prepare_data(tmpdir): @@ -226,6 +227,7 @@ def test_dm_checkpoint_save(tmpdir): default_root_dir=tmpdir, max_epochs=3, weights_summary=None, + checkpoint_callback=ModelCheckpoint(monitor='early_stop_on') ) # fit model diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 0f60194944..5e5c7afbd8 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -82,15 +82,21 @@ def test_loggers_fit_test(wandb, tmpdir, monkeypatch, logger_class): log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history] if logger_class == TensorBoardLogger: - assert log_metric_names == [(0, ['hp_metric']), - (0, ['epoch', 'val_acc', 'val_loss']), - (0, ['epoch', 'train_some_val']), - (0, ['hp_metric']), - (1, ['epoch', 'test_acc', 'test_loss'])] + expected = [ + (0, ['hp_metric']), + (0, ['epoch', 'train_some_val']), + (0, ['early_stop_on', 'epoch', 'val_acc']), + (0, ['hp_metric']), + (1, ['epoch', 'test_acc', 'test_loss']) + ] + assert log_metric_names == expected else: - assert log_metric_names == [(0, ['epoch', 'val_acc', 'val_loss']), - (0, ['epoch', 'train_some_val']), - (1, ['epoch', 'test_acc', 'test_loss'])] + expected = [ + (0, ['epoch', 'train_some_val']), + (0, ['early_stop_on', 'epoch', 'val_acc']), + (1, ['epoch', 'test_acc', 'test_loss']) + ] + assert 
+        assert log_metric_names == expected
 
 
 @pytest.mark.parametrize("logger_class", [
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 5284d4ce58..997f947996 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -10,6 +10,7 @@ from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
+from pytorch_lightning.utilities import APEX_AVAILABLE
 
 
 @pytest.mark.skip(reason='dp + amp not supported currently')  # TODO
@@ -170,6 +171,7 @@ def test_amp_without_apex(tmpdir):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
 def test_amp_with_apex(tmpdir):
     """Check calling apex scaling in training."""
     os.environ['PL_DEV_DEBUG'] = '1'
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 3cb40fcf02..438589bd71 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -162,7 +162,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template):
         max_epochs=2,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        checkpoint_callback=ModelCheckpoint(tmpdir, monitor='val_loss', save_top_k=-1),
+        checkpoint_callback=ModelCheckpoint(tmpdir, monitor='early_stop_on', save_top_k=-1),
         default_root_dir=tmpdir,
     )
diff --git a/tests/trainer/data_flow/__init__.py b/tests/trainer/data_flow/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/trainer/test_eval_loop_flow_1_0.py b/tests/trainer/data_flow/test_eval_loop_flow_1_0.py
similarity index 100%
rename from tests/trainer/test_eval_loop_flow_1_0.py
rename to tests/trainer/data_flow/test_eval_loop_flow_1_0.py
diff --git a/tests/trainer/test_train_loop_flow_dict_1_0.py b/tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py
similarity index 100%
rename from tests/trainer/test_train_loop_flow_dict_1_0.py
rename to tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py
diff --git a/tests/trainer/test_train_loop_flow_scalar_1_0.py b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py
similarity index 100%
rename from tests/trainer/test_train_loop_flow_scalar_1_0.py
rename to tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py
diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/__init__.py b/tests/trainer/legacy_deprecate_flow_log_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/trainer/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py
similarity index 97%
rename from tests/trainer/test_eval_loop_dict_return.py
rename to tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py
index 74aa8ce1fc..6071ab0f45 100644
--- a/tests/trainer/test_eval_loop_dict_return.py
+++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py
@@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
         assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) in [9, 10]
+    assert len(trainer.logger_connector.callback_metrics) in [7, 8]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
         assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) in [10, 11]
+    assert len(trainer.logger_connector.callback_metrics) in [8, 9]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) in [10, 11]
+    assert len(trainer.logger_connector.callback_metrics) in [8, 9]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) in [11, 12]
+    assert len(trainer.logger_connector.callback_metrics) in [9, 10]
 
     # make sure correct steps were called
     assert model.validation_step_called
diff --git a/tests/trainer/test_trainer_steps_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py
similarity index 100%
rename from tests/trainer/test_trainer_steps_dict_return.py
rename to tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py
diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_result_return.py
similarity index 100%
rename from tests/trainer/test_trainer_steps_result_return.py
rename to tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_result_return.py
diff --git a/tests/trainer/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
similarity index 100%
rename from tests/trainer/test_trainer_steps_scalar_return.py
rename to tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
diff --git a/tests/trainer/test_validation_steps_result_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py
similarity index 100%
rename from tests/trainer/test_validation_steps_result_return.py
rename to tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py
diff --git a/tests/trainer/logging/__init__.py b/tests/trainer/logging/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/trainer/test_eval_loop_logging_1_0.py b/tests/trainer/logging/test_eval_loop_logging_1_0.py
similarity index 90%
rename from tests/trainer/test_eval_loop_logging_1_0.py
rename to tests/trainer/logging/test_eval_loop_logging_1_0.py
index aa7ff6580e..9b6b50898e 100644
--- a/tests/trainer/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/logging/test_eval_loop_logging_1_0.py
@@ -2,7 +2,9 @@
 Tests to ensure that the training loop works with a dict (1.0)
 """
 from pytorch_lightning import Trainer
+from pytorch_lightning import callbacks
 from tests.base.deterministic_model import DeterministicModel
+from tests.base import SimpleModule
 import os
 import torch
@@ -141,5 +143,18 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
 
     # we don't want to enable val metrics during steps because it is not something that users should do
     callback_metrics = set(trainer.callback_metrics.keys())
+    callback_metrics.remove('debug_epoch')
     expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'epoch_b', 'epoch_d', 'epoch_f', 'f', 'g', 'step_b'}
     assert expected_cb_metrics == callback_metrics
+
+
+def test_monitor_val_epoch_end(tmpdir):
+    epoch_min_loss_override = 0
+    model = SimpleModule()
+    checkpoint_callback = callbacks.ModelCheckpoint(save_top_k=1, monitor="avg_val_loss")
+    trainer = Trainer(
+        max_epochs=epoch_min_loss_override + 2,
+        logger=False,
+        checkpoint_callback=checkpoint_callback,
+    )
+    trainer.fit(model)
diff --git a/tests/trainer/test_train_loop_logging_1_0.py b/tests/trainer/logging/test_train_loop_logging_1_0.py
similarity index 100%
rename from tests/trainer/test_train_loop_logging_1_0.py
rename to tests/trainer/logging/test_train_loop_logging_1_0.py
diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py
index 6b87487729..33f0ca8ced 100644
--- a/tests/trainer/test_optimizers.py
+++ b/tests/trainer/test_optimizers.py
@@ -3,6 +3,7 @@ import torch
 
 from pytorch_lightning import Trainer
 from tests.base import EvalModelTemplate
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 def test_optimizer_with_scheduling(tmpdir):
@@ -111,12 +112,36 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir):
         'lr for optimizer 2 not adjusted correctly'
 
 
-def test_reduce_lr_on_plateau_scheduling(tmpdir):
+def test_reduce_lr_on_plateau_scheduling_missing_monitor(tmpdir):
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(**hparams)
     model.configure_optimizers = model.configure_optimizers__reduce_lr_on_plateau
 
+    # fit model
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_val_batches=0.1,
+        limit_train_batches=0.2,
+    )
+
+    m = '.*ReduceLROnPlateau requires returning a dict from configure_optimizers.*'
+    with pytest.raises(MisconfigurationException, match=m):
+        trainer.fit(model)
+
+
+def test_reduce_lr_on_plateau_scheduling(tmpdir):
+    hparams = EvalModelTemplate.get_default_hparams()
+
+    class TestModel(EvalModelTemplate):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+            return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'monitor': 'early_stop_on'}
+
+    model = TestModel(**hparams)
+
     # fit model
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -128,7 +153,7 @@ def test_reduce_lr_on_plateau_scheduling(tmpdir):
 
     assert results == 1
     assert trainer.lr_schedulers[0] == \
-        dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss',
+        dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='early_stop_on',
             interval='epoch', frequency=1, reduce_on_plateau=True), \
         'lr schduler was not correctly converted to dict'
@@ -167,7 +192,7 @@ def test_optimizer_return_options():
     assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
     assert optim[0] == opt_a
     assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
-                               frequency=1, reduce_on_plateau=False, monitor='val_loss')
+                               frequency=1, reduce_on_plateau=False)
 
     # opt single dictionary
     model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a}
@@ -175,7 +200,7 @@
     assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
     assert optim[0] == opt_a
     assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
-                               frequency=1, reduce_on_plateau=False, monitor='val_loss')
+                               frequency=1, reduce_on_plateau=False)
 
     # opt multiple dictionaries with frequencies
     model.configure_optimizers = lambda: (
@@ -186,7 +211,7 @@
     assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
     assert optim[0] == opt_a
     assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
-                               frequency=1, reduce_on_plateau=False, monitor='val_loss')
+                               frequency=1, reduce_on_plateau=False)
 
     assert freq == [1, 5]
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d27a701cfa..e0049e7aad 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -432,7 +432,7 @@ def test_model_checkpoint_only_weights(tmpdir):
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
-        checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True),
+        checkpoint_callback=ModelCheckpoint(tmpdir, monitor='early_stop_on', save_weights_only=True),
     )
     # fit model
     result = trainer.fit(model)
@@ -508,7 +508,7 @@ def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
         max_epochs=2,
         limit_train_batches=0.65,
         limit_val_batches=1,
-        checkpoint_callback=ModelCheckpoint(tmpdir, monitor='val_loss', save_top_k=-1),
+        checkpoint_callback=ModelCheckpoint(tmpdir, monitor='early_stop_on', save_top_k=-1),
         default_root_dir=tmpdir,
         early_stop_callback=False,
         val_check_interval=1.,
@@ -665,7 +665,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
         max_epochs=2,
         progress_bar_refresh_rate=0,
         default_root_dir=tmpdir,
-        checkpoint_callback=ModelCheckpoint(monitor='val_loss', save_top_k=save_top_k),
+        checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', save_top_k=save_top_k),
     )
     trainer.fit(model)
     if ckpt_path == 'best':
@@ -898,6 +898,7 @@ def test_gradient_clipping_fp16(tmpdir):
 
     trainer.fit(model)
 
+
 def test_gpu_choice(tmpdir):
     trainer_options = dict(
         default_root_dir=tmpdir,