Fix logger bug and prepare data bug (#1933)
* tests, fix logger bug and prepare data bug * add CHANGELOG.md Co-authored-by: Nicki Skafte <nugginea@gmail.com>
This commit is contained in:
parent
033ddc0c29
commit
a34eb9e169
|
@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
|
||||
|
||||
- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
|
||||
|
||||
## [0.7.6] - 2020-05-16
|
||||
|
||||
### Added
|
||||
|
|
|
@ -401,6 +401,7 @@ class Trainer(
|
|||
|
||||
self.auto_lr_find = auto_lr_find
|
||||
self.auto_scale_batch_size = auto_scale_batch_size
|
||||
self._is_data_prepared = False
|
||||
self.replace_sampler_ddp = replace_sampler_ddp
|
||||
|
||||
self.truncated_bptt_steps = truncated_bptt_steps
|
||||
|
@ -823,17 +824,21 @@ class Trainer(
|
|||
# download the data and do whatever transforms we need
|
||||
# do before any spawn calls so that the model can assign properties
|
||||
# only on proc 0 because no spawn has happened yet
|
||||
if not self._is_data_prepared:
|
||||
model.prepare_data()
|
||||
self._is_data_prepared = True
|
||||
|
||||
# Run auto batch size scaling
|
||||
if self.auto_scale_batch_size:
|
||||
if isinstance(self.auto_scale_batch_size, bool):
|
||||
self.auto_scale_batch_size = 'power'
|
||||
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
|
||||
model.logger = self.logger # reset logger binding
|
||||
|
||||
# Run learning rate finder:
|
||||
if self.auto_lr_find:
|
||||
self._run_lr_finder_internally(model)
|
||||
model.logger = self.logger # reset logger binding
|
||||
|
||||
# route to appropriate start method
|
||||
# when using multi-node or DDP within a node start each module in a separate process
|
||||
|
|
|
@ -199,3 +199,26 @@ def test_suggestion_with_non_finite_values(tmpdir):
|
|||
|
||||
assert before_lr == after_lr, \
|
||||
'Learning rate was altered because of non-finite loss values'
|
||||
|
||||
|
||||
def test_logger_reset_correctly(tmpdir):
|
||||
""" Test that logger is updated correctly """
|
||||
tutils.reset_seed()
|
||||
|
||||
hparams = EvalModelTemplate.get_default_hparams()
|
||||
model = EvalModelTemplate(hparams)
|
||||
|
||||
trainer = Trainer(
|
||||
default_save_path=tmpdir,
|
||||
max_epochs=10,
|
||||
auto_lr_find=True
|
||||
)
|
||||
logger1 = trainer.logger
|
||||
trainer.fit(model)
|
||||
logger2 = trainer.logger
|
||||
logger3 = model.logger
|
||||
|
||||
assert logger1 == logger2, \
|
||||
'Learning rate finder altered the logger of trainer'
|
||||
assert logger2 == logger3, \
|
||||
'Learning rate finder altered the logger of model'
|
||||
|
|
|
@ -128,3 +128,26 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
|
|||
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer.fit(model, **fit_options)
|
||||
|
||||
|
||||
def test_logger_reset_correctly(tmpdir):
|
||||
""" Test that logger is updated correctly """
|
||||
tutils.reset_seed()
|
||||
|
||||
hparams = EvalModelTemplate.get_default_hparams()
|
||||
model = EvalModelTemplate(hparams)
|
||||
|
||||
trainer = Trainer(
|
||||
default_save_path=tmpdir,
|
||||
max_epochs=1,
|
||||
auto_scale_batch_size=True
|
||||
)
|
||||
logger1 = trainer.logger
|
||||
trainer.fit(model)
|
||||
logger2 = trainer.logger
|
||||
logger3 = model.logger
|
||||
|
||||
assert logger1 == logger2, \
|
||||
'Batch size finder altered the logger of trainer'
|
||||
assert logger2 == logger3, \
|
||||
'Batch size finder altered the logger of model'
|
||||
|
|
Loading…
Reference in New Issue