From a34eb9e169622fe91fdf4d98560b65b2f2b5c8d0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 25 May 2020 13:43:56 +0200 Subject: [PATCH] Fix logger bug and prepare data bug (#1933) * tests, fix logger bug and prepare data bug * add CHANGELOG.md Co-authored-by: Nicki Skafte --- CHANGELOG.md | 2 ++ pytorch_lightning/trainer/trainer.py | 7 ++++++- tests/trainer/test_lr_finder.py | 23 +++++++++++++++++++++++ tests/trainer/test_trainer_tricks.py | 23 +++++++++++++++++++++++ 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d56f4a928..505c77ccfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873)) +- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933)) + ## [0.7.6] - 2020-05-16 ### Added diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3f51a168ad..a366890bc2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -401,6 +401,7 @@ class Trainer( self.auto_lr_find = auto_lr_find self.auto_scale_batch_size = auto_scale_batch_size + self._is_data_prepared = False self.replace_sampler_ddp = replace_sampler_ddp self.truncated_bptt_steps = truncated_bptt_steps @@ -823,17 +824,21 @@ class Trainer( # download the data and do whatever transforms we need # do before any spawn calls so that the model can assign properties # only on proc 0 because no spawn has happened yet - model.prepare_data() + if not self._is_data_prepared: + model.prepare_data() + self._is_data_prepared = True # Run auto batch size scaling if self.auto_scale_batch_size: if isinstance(self.auto_scale_batch_size, bool): self.auto_scale_batch_size = 'power' self.scale_batch_size(model, mode=self.auto_scale_batch_size) + model.logger = self.logger # reset logger binding # Run learning rate finder: if self.auto_lr_find: self._run_lr_finder_internally(model) + model.logger = self.logger # reset logger binding # route to appropriate start method # when using multi-node or DDP within a node start each module in a separate process diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 67dd6a6f3d..d1e235b0a6 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -199,3 +199,26 @@ def test_suggestion_with_non_finite_values(tmpdir): assert before_lr == after_lr, \ 'Learning rate was altered because of non-finite loss values' + + +def test_logger_reset_correctly(tmpdir): + """ Test that logger is updated correctly """ + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(hparams) + + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=10, + auto_lr_find=True + ) + logger1 = trainer.logger + trainer.fit(model) + logger2 = trainer.logger + logger3 = model.logger + + assert logger1 == logger2, \ + 'Learning rate finder altered the logger of trainer' + assert logger2 == logger3, \ + 'Learning rate finder altered the logger of model' diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index a66e8bbde8..81eb7e1355 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -128,3 +128,26 @@ def test_error_on_dataloader_passed_to_fit(tmpdir): with pytest.raises(MisconfigurationException): trainer.fit(model, **fit_options) + + +def test_logger_reset_correctly(tmpdir): + """ Test that logger is updated correctly """ + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(hparams) + + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + auto_scale_batch_size=True + ) + logger1 = trainer.logger + trainer.fit(model) + logger2 = trainer.logger + logger3 = model.logger + + assert logger1 == logger2, \ + 'Batch size finder altered the logger of trainer' + assert logger2 == logger3, \ + 'Batch size finder altered the logger of model'