Fix batch_outputs with optimizer frequencies (#3229)

* Fix batch_outputs with optimizer frequencies

* optimizers

* fix batch_outputs with optimizer frequencies

* clean test

* suggestion

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* chlog

* failing doctest

* failing doctest

* update doctest

* chlog

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Rohit Gupta authored 2020-09-11 02:31:20 +05:30, committed by GitHub
commit a1ea681c47 (parent 3281586ab4)
5 changed files with 35 additions and 9 deletions


@@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `GpuUsageLogger` to work on different platforms ([#3008](https://github.com/PyTorchLightning/pytorch-lightning/pull/3008))
+- Fixed batch_outputs with optimizer frequencies ([#3229](https://github.com/PyTorchLightning/pytorch-lightning/pull/3229))
 - Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266))
 - Fixed Horovod distributed backend compatibility with native AMP ([#3404](https://github.com/PyTorchLightning/pytorch-lightning/pull/3404))


@@ -16,9 +16,9 @@ To enable your code to work with Lightning, here's how to organize PyTorch into
 ===============================
 Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.LightningModule`.

-.. code-block::
+.. testcode::

-    class LitModel(pl.LightningModule):
+    class LitModel(LightningModule):

         def __init__(self):
             super().__init__()
@@ -36,9 +36,9 @@ Move the model architecture and forward pass to your :class:`~pytorch_lightning.
 =======================================
 Move your optimizers to the :func:`pytorch_lightning.core.LightningModule.configure_optimizers` hook. Make sure to use the hook parameters (self in this case).

-.. code-block::
+.. testcode::

-    class LitModel(pl.LightningModule):
+    class LitModel(LightningModule):

         def configure_optimizers(self):
             optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
@@ -48,9 +48,9 @@ Move your optimizers to :func:`pytorch_lightning.core.LightningModule.configure_
 =============================
 Lightning automates most of the training for you (the epoch and batch iterations); all you need to keep is the training step logic. This should go into the :func:`pytorch_lightning.core.LightningModule.training_step` hook (make sure to use the hook parameters, self in this case):

-.. code-block::
+.. testcode::

-    class LitModel(pl.LightningModule):
+    class LitModel(LightningModule):

         def training_step(self, batch, batch_idx):
             x, y = batch
@@ -78,9 +78,9 @@ To add an (optional) validation loop add logic to :func:`pytorch_lightning.core.
 ============================
 To add an (optional) test loop, add logic to the :func:`pytorch_lightning.core.LightningModule.test_step` hook (make sure to use the hook parameters, self in this case).

-.. code-block::
+.. testcode::

-    class LitModel(pl.LightningModule):
+    class LitModel(LightningModule):

         def test_step(self, batch, batch_idx):
             x, y = batch

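Taken together, the four documentation hunks switch the guide's snippets from `.. code-block::` to `.. testcode::`, so the Sphinx doctest builder actually executes them, and drop the `pl.` prefix, presumably because the doctest setup imports `LightningModule` directly. As a rough sketch only (the layer size, loss function, and return values below are illustrative and not part of the diff), the fragments assemble into something like:

    import torch
    from torch.nn import functional as F
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            # illustrative architecture, not taken from the diff
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.cross_entropy(self(x), y)

        def test_step(self, batch, batch_idx):
            x, y = batch
            return {'test_loss': F.cross_entropy(self(x), y)}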

@@ -500,7 +500,8 @@ class TrainLoop:
                 self.accumulated_loss.append(opt_closure_result.loss)

                 # track all the outputs across all steps
-                batch_outputs[opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)
+                batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
+                batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)

                 # ------------------------------
                 # BACKWARD PASS

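For context on the one-line change above: when `configure_optimizers` returns optimizers with a `frequency`, only one optimizer is active per batch, so `batch_outputs` can be built with a single slot while `opt_idx` still carries the optimizer's position in the full list (for example `opt_idx == 1` with `len(batch_outputs) == 1`), and indexing by `opt_idx` failed. The sketch below isolates that indexing rule; the helper function and toy data are illustrative, not the actual trainer code:

    # Minimal sketch of the indexing rule introduced above (not the real TrainLoop).
    def append_step_output(batch_outputs, opt_idx, output):
        # batch_outputs holds one slot per optimizer that runs in this batch;
        # with frequencies, that is a single slot even when opt_idx is 1, 2, ...
        batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
        batch_outputs[batch_opt_idx].append(output)


    # Two optimizers configured with frequencies, only the second one active:
    batch_outputs = [[]]                              # a single slot this batch
    append_step_output(batch_outputs, opt_idx=1, output={'loss': 0.42})
    assert batch_outputs == [[{'loss': 0.42}]]        # indexing no longer fails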

@@ -34,6 +34,14 @@ class ConfigureOptimizersPool(ABC):
         optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate)
         return optimizer1, optimizer2

+    def configure_optimizers__multiple_optimizers_frequency(self):
+        optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
+        optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate)
+        return [
+            {'optimizer': optimizer1, 'frequency': 1},
+            {'optimizer': optimizer2, 'frequency': 5}
+        ]
+
     def configure_optimizers__single_scheduler(self):
         optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
         lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)

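The new fixture uses the dictionary format of `configure_optimizers`, where each entry carries an `'optimizer'` and a `'frequency'`. The intended semantics appear to be that optimizers are cycled in proportion to their frequencies: with frequencies 1 and 5, optimizer 1 handles one batch, optimizer 2 the next five, and the pattern repeats. The helper below is a hypothetical stand-in that approximates this schedule, not Lightning's implementation:

    from itertools import accumulate

    # Hypothetical helper approximating how frequency-based cycling assigns
    # an optimizer to each batch index; not taken from the Lightning code base.
    def active_optimizer_index(batch_idx, frequencies):
        position = batch_idx % sum(frequencies)
        for opt_idx, boundary in enumerate(accumulate(frequencies)):
            if position < boundary:
                return opt_idx


    schedule = [active_optimizer_index(i, [1, 5]) for i in range(8)]
    assert schedule == [0, 1, 1, 1, 1, 1, 0, 1]   # optimizer1 once, then optimizer2 five times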

@@ -240,3 +240,18 @@ def test_configure_optimizer_from_dict(tmpdir):
     )
     result = trainer.fit(model)
     assert result == 1
+
+
+def test_configure_optimizers_with_frequency(tmpdir):
+    """
+    Test that multiple optimizers work when corresponding frequency is set.
+    """
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__multiple_optimizers_frequency
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1
+    )
+    result = trainer.fit(model)
+    assert result