Fix logger bug and prepare data bug (#1933)

* tests, fix logger bug and prepare data bug

* add CHANGELOG.md

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
This commit is contained in:
Nicki Skafte 2020-05-25 13:43:56 +02:00 committed by GitHub
parent 033ddc0c29
commit a34eb9e169
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 1 deletions

View File

@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
## [0.7.6] - 2020-05-16
### Added

View File

@ -401,6 +401,7 @@ class Trainer(
self.auto_lr_find = auto_lr_find
self.auto_scale_batch_size = auto_scale_batch_size
self._is_data_prepared = False
self.replace_sampler_ddp = replace_sampler_ddp
self.truncated_bptt_steps = truncated_bptt_steps
@ -823,17 +824,21 @@ class Trainer(
# download the data and do whatever transforms we need
# do before any spawn calls so that the model can assign properties
# only on proc 0 because no spawn has happened yet
model.prepare_data()
if not self._is_data_prepared:
model.prepare_data()
self._is_data_prepared = True
# Run auto batch size scaling
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
self.auto_scale_batch_size = 'power'
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
model.logger = self.logger # reset logger binding
# Run learning rate finder:
if self.auto_lr_find:
self._run_lr_finder_internally(model)
model.logger = self.logger # reset logger binding
# route to appropriate start method
# when using multi-node or DDP within a node start each module in a separate process

View File

@ -199,3 +199,26 @@ def test_suggestion_with_non_finite_values(tmpdir):
assert before_lr == after_lr, \
'Learning rate was altered because of non-finite loss values'
def test_logger_reset_correctly(tmpdir):
""" Test that logger is updated correctly """
tutils.reset_seed()
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=10,
auto_lr_find=True
)
logger1 = trainer.logger
trainer.fit(model)
logger2 = trainer.logger
logger3 = model.logger
assert logger1 == logger2, \
'Learning rate finder altered the logger of trainer'
assert logger2 == logger3, \
'Learning rate finder altered the logger of model'

View File

@ -128,3 +128,26 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
with pytest.raises(MisconfigurationException):
trainer.fit(model, **fit_options)
def test_logger_reset_correctly(tmpdir):
""" Test that logger is updated correctly """
tutils.reset_seed()
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
auto_scale_batch_size=True
)
logger1 = trainer.logger
trainer.fit(model)
logger2 = trainer.logger
logger3 = model.logger
assert logger1 == logger2, \
'Batch size finder altered the logger of trainer'
assert logger2 == logger3, \
'Batch size finder altered the logger of model'