Fixed configure optimizer from dict without "scheduler" key (#1443)
* `configure_optimizer` from dict with only "optimizer" key. bug fixed * autopep8 * pep8speaks suggested fixes * CHANGELOG.md upd
This commit is contained in:
parent
7857a73710
commit
4c34d16a34
|
@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
- Fixed optimizer configuration when `configure_optimizers` returns dict without `lr_scheduler` ([#1443](https://github.com/PyTorchLightning/pytorch-lightning/pull/1443))
|
||||||
- Fixed default `DistributedSampler` for DDP training ([#1425](https://github.com/PyTorchLightning/pytorch-lightning/pull/1425))
|
- Fixed default `DistributedSampler` for DDP training ([#1425](https://github.com/PyTorchLightning/pytorch-lightning/pull/1425))
|
||||||
- Fixed workers warning not on windows ([#1430](https://github.com/PyTorchLightning/pytorch-lightning/pull/1430))
|
- Fixed workers warning not on windows ([#1430](https://github.com/PyTorchLightning/pytorch-lightning/pull/1430))
|
||||||
- Fixed returning tuple from `run_training_batch` ([#1431](https://github.com/PyTorchLightning/pytorch-lightning/pull/1431))
|
- Fixed returning tuple from `run_training_batch` ([#1431](https://github.com/PyTorchLightning/pytorch-lightning/pull/1431))
|
||||||
|
|
|
@ -39,6 +39,8 @@ class TrainerOptimizersMixin(ABC):
|
||||||
lr_scheduler = optim_conf.get("lr_scheduler", [])
|
lr_scheduler = optim_conf.get("lr_scheduler", [])
|
||||||
if lr_scheduler:
|
if lr_scheduler:
|
||||||
lr_schedulers = self.configure_schedulers([lr_scheduler])
|
lr_schedulers = self.configure_schedulers([lr_scheduler])
|
||||||
|
else:
|
||||||
|
lr_schedulers = []
|
||||||
return [optimizer], lr_schedulers, []
|
return [optimizer], lr_schedulers, []
|
||||||
|
|
||||||
# multiple dictionaries
|
# multiple dictionaries
|
||||||
|
|
|
@ -20,6 +20,7 @@ class TensorRunningAccum(object):
|
||||||
>>> accum.last(), accum.mean(), accum.min(), accum.max()
|
>>> accum.last(), accum.mean(), accum.min(), accum.max()
|
||||||
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
|
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, window_length: int):
|
def __init__(self, window_length: int):
|
||||||
self.window_length = window_length
|
self.window_length = window_length
|
||||||
self.memory = torch.Tensor(self.window_length)
|
self.memory = torch.Tensor(self.window_length)
|
||||||
|
|
|
@ -554,7 +554,8 @@ class Trainer(
|
||||||
if at[0] not in depr_arg_names):
|
if at[0] not in depr_arg_names):
|
||||||
for allowed_type in (at for at in allowed_types if at in arg_types):
|
for allowed_type in (at for at in allowed_types if at in arg_types):
|
||||||
if isinstance(allowed_type, bool):
|
if isinstance(allowed_type, bool):
|
||||||
allowed_type = lambda x: bool(distutils.util.strtobool(x))
|
def allowed_type(x):
|
||||||
|
return bool(distutils.util.strtobool(x))
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f'--{arg}',
|
f'--{arg}',
|
||||||
default=arg_default,
|
default=arg_default,
|
||||||
|
|
|
@ -275,3 +275,26 @@ def test_none_optimizer(tmpdir):
|
||||||
|
|
||||||
# verify training completed
|
# verify training completed
|
||||||
assert result == 1
|
assert result == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_configure_optimizer_from_dict(tmpdir):
|
||||||
|
"""Tests if `configure_optimizer` method could return a dictionary with
|
||||||
|
`optimizer` field only.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class CurrentTestModel(LightTrainDataloader, TestModelBase):
|
||||||
|
def configure_optimizers(self):
|
||||||
|
config = {
|
||||||
|
'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03)
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
hparams = tutils.get_default_hparams()
|
||||||
|
model = CurrentTestModel(hparams)
|
||||||
|
|
||||||
|
trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
|
||||||
|
|
||||||
|
# fit model
|
||||||
|
trainer = Trainer(**trainer_options)
|
||||||
|
result = trainer.fit(model)
|
||||||
|
assert result == 1
|
||||||
|
|
Loading…
Reference in New Issue