Fixed configure optimizer from dict without "scheduler" key (#1443)

* `configure_optimizer` from dict with only "optimizer" key. bug fixed

* autopep8

* pep8speaks suggested fixes

* CHANGELOG.md upd
This commit is contained in:
Alexey Karnachev 2020-04-10 18:43:06 +03:00 committed by GitHub
parent 7857a73710
commit 4c34d16a34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 29 additions and 1 deletions

View File

@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed optimizer configuration when `configure_optimizers` returns dict without `lr_scheduler` ([#1443](https://github.com/PyTorchLightning/pytorch-lightning/pull/1443))
- Fixed default `DistributedSampler` for DDP training ([#1425](https://github.com/PyTorchLightning/pytorch-lightning/pull/1425))
- Fixed workers warning not on windows ([#1430](https://github.com/PyTorchLightning/pytorch-lightning/pull/1430))
- Fixed returning tuple from `run_training_batch` ([#1431](https://github.com/PyTorchLightning/pytorch-lightning/pull/1431))

View File

@ -39,6 +39,8 @@ class TrainerOptimizersMixin(ABC):
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler])
else:
lr_schedulers = []
return [optimizer], lr_schedulers, []
# multiple dictionaries

View File

@ -20,6 +20,7 @@ class TensorRunningAccum(object):
>>> accum.last(), accum.mean(), accum.min(), accum.max()
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
"""
def __init__(self, window_length: int):
self.window_length = window_length
self.memory = torch.Tensor(self.window_length)

View File

@ -554,7 +554,8 @@ class Trainer(
if at[0] not in depr_arg_names):
for allowed_type in (at for at in allowed_types if at in arg_types):
if isinstance(allowed_type, bool):
allowed_type = lambda x: bool(distutils.util.strtobool(x))
def allowed_type(x):
return bool(distutils.util.strtobool(x))
parser.add_argument(
f'--{arg}',
default=arg_default,

View File

@ -275,3 +275,26 @@ def test_none_optimizer(tmpdir):
# verify training completed
assert result == 1
def test_configure_optimizer_from_dict(tmpdir):
"""Tests if `configure_optimizer` method could return a dictionary with
`optimizer` field only.
"""
class CurrentTestModel(LightTrainDataloader, TestModelBase):
def configure_optimizers(self):
config = {
'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03)
}
return config
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1