Fixed configure optimizer from dict without "scheduler" key (#1443)
* `configure_optimizer` from dict with only "optimizer" key. bug fixed * autopep8 * pep8speaks suggested fixes * CHANGELOG.md upd
This commit is contained in:
parent
7857a73710
commit
4c34d16a34
|
@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- Fixed optimizer configuration when `configure_optimizers` returns dict without `lr_scheduler` ([#1443](https://github.com/PyTorchLightning/pytorch-lightning/pull/1443))
|
||||
- Fixed default `DistributedSampler` for DDP training ([#1425](https://github.com/PyTorchLightning/pytorch-lightning/pull/1425))
|
||||
- Fixed workers warning not on windows ([#1430](https://github.com/PyTorchLightning/pytorch-lightning/pull/1430))
|
||||
- Fixed returning tuple from `run_training_batch` ([#1431](https://github.com/PyTorchLightning/pytorch-lightning/pull/1431))
|
||||
|
|
|
@ -39,6 +39,8 @@ class TrainerOptimizersMixin(ABC):
|
|||
lr_scheduler = optim_conf.get("lr_scheduler", [])
|
||||
if lr_scheduler:
|
||||
lr_schedulers = self.configure_schedulers([lr_scheduler])
|
||||
else:
|
||||
lr_schedulers = []
|
||||
return [optimizer], lr_schedulers, []
|
||||
|
||||
# multiple dictionaries
|
||||
|
|
|
@ -20,6 +20,7 @@ class TensorRunningAccum(object):
|
|||
>>> accum.last(), accum.mean(), accum.min(), accum.max()
|
||||
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
|
||||
"""
|
||||
|
||||
def __init__(self, window_length: int):
|
||||
self.window_length = window_length
|
||||
self.memory = torch.Tensor(self.window_length)
|
||||
|
|
|
@ -554,7 +554,8 @@ class Trainer(
|
|||
if at[0] not in depr_arg_names):
|
||||
for allowed_type in (at for at in allowed_types if at in arg_types):
|
||||
if isinstance(allowed_type, bool):
|
||||
allowed_type = lambda x: bool(distutils.util.strtobool(x))
|
||||
def allowed_type(x):
|
||||
return bool(distutils.util.strtobool(x))
|
||||
parser.add_argument(
|
||||
f'--{arg}',
|
||||
default=arg_default,
|
||||
|
|
|
@ -275,3 +275,26 @@ def test_none_optimizer(tmpdir):
|
|||
|
||||
# verify training completed
|
||||
assert result == 1
|
||||
|
||||
|
||||
def test_configure_optimizer_from_dict(tmpdir):
|
||||
"""Tests if `configure_optimizer` method could return a dictionary with
|
||||
`optimizer` field only.
|
||||
"""
|
||||
|
||||
class CurrentTestModel(LightTrainDataloader, TestModelBase):
|
||||
def configure_optimizers(self):
|
||||
config = {
|
||||
'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03)
|
||||
}
|
||||
return config
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
|
||||
trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
assert result == 1
|
||||
|
|
Loading…
Reference in New Issue