Fix logger bug and prepare data bug (#1933)
* tests, fix logger bug and prepare data bug
* add CHANGELOG.md

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
commit a34eb9e169
parent 033ddc0c29
CHANGELOG.md

@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
 
+- Fixed bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
+
 ## [0.7.6] - 2020-05-16
 
 ### Added
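For context on the CHANGELOG entry above: the two fixes in this commit share one pattern, guard data preparation behind a flag so it runs at most once per trainer, and rebind the model's logger to the trainer's logger after a tuner routine has run. The sketch below is a minimal, standalone illustration of that pattern, not Lightning's actual classes; ToyTrainer, ToyModel, and _tune_batch_size are hypothetical names.

# Minimal standalone sketch of the pattern used in this commit.
# ToyTrainer / ToyModel / _tune_batch_size are illustrative names,
# not part of the PyTorch Lightning API.

class ToyModel:
    def __init__(self):
        self.logger = None
        self.prepare_calls = 0

    def prepare_data(self):
        # Expensive step that should run at most once per trainer.
        self.prepare_calls += 1


class ToyTrainer:
    def __init__(self, logger="real-logger", auto_scale_batch_size=False):
        self.logger = logger
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False  # same flag the commit adds to Trainer

    def _tune_batch_size(self, model):
        # A tuner routine may run with a temporary logger and restore the
        # trainer's one afterwards, leaving the model with a stale reference.
        model.logger = "temporary-logger"

    def fit(self, model):
        # Guard prepare_data so repeated or internal fit calls don't rerun it.
        if not self._is_data_prepared:
            model.prepare_data()
            self._is_data_prepared = True

        if self.auto_scale_batch_size:
            self._tune_batch_size(model)
            model.logger = self.logger  # reset logger binding, as in the commit


if __name__ == "__main__":
    trainer = ToyTrainer(auto_scale_batch_size=True)
    model = ToyModel()
    trainer.fit(model)
    trainer.fit(model)  # second call: prepare_data is not rerun
    assert model.prepare_calls == 1
    assert model.logger == trainer.logger
    print("prepare_data ran once and the logger binding was reset")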
@@ -401,6 +401,7 @@ class Trainer(
         self.auto_lr_find = auto_lr_find
         self.auto_scale_batch_size = auto_scale_batch_size
+        self._is_data_prepared = False
         self.replace_sampler_ddp = replace_sampler_ddp
 
         self.truncated_bptt_steps = truncated_bptt_steps
@@ -823,17 +824,21 @@ class Trainer(
         # download the data and do whatever transforms we need
         # do before any spawn calls so that the model can assign properties
         # only on proc 0 because no spawn has happened yet
-        model.prepare_data()
+        if not self._is_data_prepared:
+            model.prepare_data()
+            self._is_data_prepared = True
 
         # Run auto batch size scaling
         if self.auto_scale_batch_size:
             if isinstance(self.auto_scale_batch_size, bool):
                 self.auto_scale_batch_size = 'power'
             self.scale_batch_size(model, mode=self.auto_scale_batch_size)
+            model.logger = self.logger  # reset logger binding
 
         # Run learning rate finder:
        if self.auto_lr_find:
             self._run_lr_finder_internally(model)
+            model.logger = self.logger  # reset logger binding
 
         # route to appropriate start method
         # when using multi-node or DDP within a node start each module in a separate process
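A note on the trainer change above: both tuner routines (scale_batch_size and _run_lr_finder_internally) appear to drive an internal fit, which can leave model.logger pointing at a stale or temporary logger even though trainer.logger is restored afterwards; the explicit "model.logger = self.logger  # reset logger binding" lines put the trainer's real logger back on the model. The same internal fit is why prepare_data() is now guarded by _is_data_prepared, so it runs at most once per trainer. The new tests below verify both loggers end up identical after fitting with auto_lr_find=True or auto_scale_batch_size=True.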
@@ -199,3 +199,26 @@ def test_suggestion_with_non_finite_values(tmpdir):
 
     assert before_lr == after_lr, \
         'Learning rate was altered because of non-finite loss values'
+
+
+def test_logger_reset_correctly(tmpdir):
+    """ Test that logger is updated correctly """
+    tutils.reset_seed()
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=10,
+        auto_lr_find=True
+    )
+    logger1 = trainer.logger
+    trainer.fit(model)
+    logger2 = trainer.logger
+    logger3 = model.logger
+
+    assert logger1 == logger2, \
+        'Learning rate finder altered the logger of trainer'
+    assert logger2 == logger3, \
+        'Learning rate finder altered the logger of model'
@@ -128,3 +128,26 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
 
     with pytest.raises(MisconfigurationException):
         trainer.fit(model, **fit_options)
+
+
+def test_logger_reset_correctly(tmpdir):
+    """ Test that logger is updated correctly """
+    tutils.reset_seed()
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        auto_scale_batch_size=True
+    )
+    logger1 = trainer.logger
+    trainer.fit(model)
+    logger2 = trainer.logger
+    logger3 = model.logger
+
+    assert logger1 == logger2, \
+        'Batch size finder altered the logger of trainer'
+    assert logger2 == logger3, \
+        'Batch size finder altered the logger of model'