Add `ManualOptimization` loop (#9266)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Adrian Wälchli 2021-09-08 02:26:39 +02:00 committed by GitHub
parent a79c351a6a
commit ca679cd78f
9 changed files with 179 additions and 152 deletions

View File

@@ -69,6 +69,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
* Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
* Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
* Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))

View File

@@ -63,6 +63,7 @@ ignore_errors = "True"
module = [
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.loops.closure",
"pytorch_lightning.loops.batch.manual",
"pytorch_lightning.loops.optimizer",
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector.*",

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from pytorch_lightning.loops.base import Loop # noqa: F401
from pytorch_lightning.loops.batch import ManualOptimization # noqa: F401
from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401

View File

@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.batch.manual import ManualOptimization # noqa: F401
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401

View File

@@ -0,0 +1,89 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.utilities import (
    _build_training_step_kwargs,
    _check_training_step_output,
    _process_training_step_output,
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection


class ManualOptimization(Loop):
    """A special loop implementing what is known in Lightning as Manual Optimization, where the optimization happens
    entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore the user
    is responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
    :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`) and passes the output(s) through.
    """

    def __init__(self) -> None:
        super().__init__()
        self._done: bool = False
        self._hiddens: Optional[Any] = None
        self._output: Optional[ResultCollection] = None

    @property
    def done(self) -> bool:
        return self._done

    def reset(self) -> None:
        self._done = False

    def advance(self, batch: Any, batch_idx: int, hiddens: Optional[Any] = None) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            hiddens: the model's hidden state of the previous iteration
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):
            step_kwargs = _build_training_step_kwargs(
                lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=hiddens
            )

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            result_collection, hiddens = _process_training_step_output(self.trainer, training_step_output)

        self._done = True
        self._hiddens = hiddens
        self._output = result_collection

    def on_run_end(self) -> Tuple[Optional[ResultCollection], Optional[Any]]:
        """Returns the result of this loop, i.e., the post-processed outputs from the training step, and the hidden
        state."""
        output = self._output
        hiddens = self._hiddens
        self._output, self._hiddens = None, None  # free memory
        return output, hiddens

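For context on what this new loop wraps: under manual optimization, a LightningModule sets `self.automatic_optimization = False` and performs backward and the optimizer calls itself inside `training_step`, which is exactly the step that `ManualOptimization.advance` executes above. A minimal sketch of such a module (the `ManualBoringModel` name and shapes below are illustrative, not part of this commit):

import torch
from pytorch_lightning import LightningModule


class ManualBoringModel(LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # opt out of automatic optimization; the ManualOptimization loop then runs this step
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        out = self.layer(batch)
        loss = torch.nn.functional.mse_loss(out, torch.zeros_like(out))
        # the user is responsible for backward and the optimizer step
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)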
View File

@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from functools import partial
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
import numpy as np
from deprecate import void
@@ -21,13 +20,8 @@ from torch import Tensor
from torch.optim import Optimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.batch.manual import ManualOptimization
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import (
_build_training_step_kwargs,
_check_training_step_output,
_process_training_step_output,
)
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -45,6 +39,7 @@ class TrainingBatchLoop(Loop):
# the current split index when the batch gets split into chunks in truncated backprop through time
self.split_idx: Optional[int] = None
self.optimizer_loop = OptimizerLoop()
self.manual_loop = ManualOptimization()
self._warning_cache: WarningCache = WarningCache()
self._hiddens: Optional[Tensor] = None
@@ -63,8 +58,13 @@
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum
def connect(self, optimizer_loop: "Loop") -> None:
self.optimizer_loop = optimizer_loop
def connect(
self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
) -> None:
if optimizer_loop is not None:
self.optimizer_loop = optimizer_loop
if manual_loop is not None:
self.manual_loop = manual_loop
def run(self, batch: Any, batch_idx: int) -> AttributeDict:
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
@@ -132,10 +132,10 @@
for k in range(len(batch_outputs)):
self.batch_outputs[k].extend(batch_outputs[k])
else:
# in manual optimization, there is no looping over optimizers
result = self._run_optimization(batch_idx, split_batch)
if result:
self.batch_outputs[0].append(deepcopy(result.result_collection))
# in manual optimization, hand over execution to the ManualOptimization loop
output, self._hiddens = self.manual_loop.run(split_batch, batch_idx, self._hiddens)
if output:
self.batch_outputs[0].append(deepcopy(output))
def teardown(self) -> None:
# release memory
@@ -145,89 +145,6 @@
"""Gets the number of active optimizers based on their frequency."""
return len(self.get_active_optimizers(batch_idx))
def _run_optimization(
self,
batch_idx: int,
split_batch: Any,
) -> Optional[ClosureResult]:
"""Runs closure (train step + backward) together with optimization if necessary.
Args:
batch_idx: the index of the current batch
split_batch: the current tbptt split of the whole batch
"""
# TODO: replace call through closure by direct call (manual optimization)
closure = self._make_closure(split_batch, batch_idx, self._hiddens)
closure()
result = closure.get_result()
if result:
# if no result, user decided to skip optimization
# otherwise update running loss + reset accumulated loss
self._update_running_loss(result.loss)
return result
def _make_closure(
self,
split_batch: Any,
batch_idx: int,
hiddens: Any,
) -> Closure:
"""Build a closure object that captures the given arguments and runs the `training_step` function and
optionally other functions such as `backward` and `zero_grad`."""
step_fn = self._make_step_fn(split_batch, batch_idx, hiddens)
backward_fn = None
zero_grad_fn = None
return Closure(
step_fn=step_fn,
backward_fn=backward_fn,
zero_grad_fn=zero_grad_fn,
profiler=self.trainer.profiler,
)
def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]:
"""Build the step function that runs the `training_step` and processes its output."""
return partial(self._training_step, split_batch, batch_idx, hiddens)
def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
"""Performs the training step for manual optimization.
Args:
split_batch: the current tbptt split of the current batch
batch_idx: the index of the current batch
hiddens: the model's hidden state of the previous iteration
Returns:
an AttributeDict containing the training step output.
"""
# give the PL module a result for logging
model_ref = self.trainer.lightning_module
with self.trainer.profiler.profile("model_forward"):
step_kwargs = _build_training_step_kwargs(
model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
)
# manually capture logged metrics
model_ref._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()
del step_kwargs
training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
_check_training_step_output(self.trainer.lightning_module, training_step_output)
result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
if result_collection is None:
return
return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)
def _tbptt_split_batch(self, batch: Any) -> List[Any]:
"""Splits a single batch into a list of sequence steps for tbptt.

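Usage note for the widened `connect()` signature in this file: each child loop can now be replaced independently of the other. A minimal sketch, assuming the batch loop is reachable as `trainer.fit_loop.epoch_loop.batch_loop` (the nesting suggested by the `epoch_loop.batch_loop.*` state-dict keys in the tests below) and using a hypothetical `MyManualLoop` subclass:

from pytorch_lightning import Trainer
from pytorch_lightning.loops.batch import ManualOptimization


class MyManualLoop(ManualOptimization):  # hypothetical subclass for illustration
    def advance(self, batch, batch_idx, hiddens=None):
        # custom bookkeeping could go here before delegating to the default behaviour
        super().advance(batch, batch_idx, hiddens)


trainer = Trainer()
# swap in the custom manual-optimization loop; the optimizer loop keeps its default
trainer.fit_loop.epoch_loop.batch_loop.connect(manual_loop=MyManualLoop())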
View File

@@ -56,8 +56,9 @@ def test_loops_state_dict_structure():
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_loop.manual_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},

View File

@@ -464,6 +464,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer_idx": stop_optimizer,

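The two test updates above register the new child loop under its own key in the nested loop state dict. A quick way to observe the new entry, assuming (as these tests do) that `state_dict()` can be called on the fit loop and that `tests.helpers.BoringModel` is available in the repository's test environment:

from pytorch_lightning import Trainer
from tests.helpers import BoringModel

trainer = Trainer(fast_dev_run=True)
trainer.fit(BoringModel())

state = trainer.fit_loop.state_dict()
# the ManualOptimization child now contributes its own (initially empty) state entry
assert "epoch_loop.batch_loop.manual_loop.state_dict" in state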
View File

@@ -14,82 +14,97 @@
import pytest
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
class LinearModel(BoringModel):
"""Linear model for testing TBPTT with automatic optimization."""
def __init__(self, truncated_bptt_steps=2, n_hidden_states=1, sequence_size=30, batch_size=30):
super().__init__()
self.truncated_bptt_steps = truncated_bptt_steps
self.n_hidden_states = n_hidden_states
self.sequence_size = sequence_size
self.batch_size = batch_size
self.automatic_optimization = True
self.example_input_array = torch.randn(5, truncated_bptt_steps)
self.layer = torch.nn.Linear(in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
self.test_hidden = None
def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
if self.n_hidden_states == 1:
self.test_hidden = torch.rand(1)
else:
self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
x_tensor, y_list = batch
assert x_tensor.shape[1] == self.truncated_bptt_steps, "tbptt split Tensor failed"
y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
assert y_tensor.shape[1] == self.truncated_bptt_steps, "tbptt split list failed"
pred = self(x_tensor.view(self.batch_size, self.truncated_bptt_steps))
loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(self.batch_size, self.truncated_bptt_steps))
return {"loss": loss_val, "hiddens": self.test_hidden}
def training_epoch_end(self, training_step_outputs):
training_step_outputs = training_step_outputs[0]
assert len(training_step_outputs) == (self.sequence_size / self.truncated_bptt_steps)
loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
assert loss.grad_fn is None
self.log("train_loss", loss)
class ManualLinearModel(LinearModel):
"""Linear model for testing TBPTT with manual optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, batch, batch_idx, hiddens):
out = super().training_step(batch, batch_idx, hiddens)
loss, hiddens = out["loss"], out["hiddens"]
opt = self.optimizers()
opt.zero_grad()
self.manual_backward(loss)
opt.step()
assert loss.grad_fn is not None
return {"loss": loss, "hiddens": hiddens}
@pytest.mark.parametrize("model_class", (LinearModel, ManualLinearModel))
@pytest.mark.parametrize("n_hidden_states", (1, 2))
def test_tbptt_cpu_model(tmpdir, n_hidden_states):
"""Test truncated back propagation through time works."""
truncated_bptt_steps = 2
def test_tbptt_cpu_model_manual(tmpdir, n_hidden_states, model_class):
"""Test truncated back propagation through time works with automatic and manual optimization."""
sequence_size = 30
batch_size = 30
x_seq = torch.rand(batch_size, sequence_size, 1)
y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
class MockSeq2SeqDataset(torch.utils.data.Dataset):
class MockSeq2SeqDataset(Dataset):
def __getitem__(self, i):
return x_seq, y_seq_list
def __len__(self):
return 1
class BpttTestModel(BoringModel):
def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
super().__init__(*args, **kwargs)
self.test_hidden = None
self.batch_size = batch_size
self.layer = torch.nn.Linear(in_features, out_features)
self.n_hidden_states = n_hidden_states
self.truncated_bptt_steps = truncated_bptt_steps
def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
if self.n_hidden_states == 1:
self.test_hidden = torch.rand(1)
else:
self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
x_tensor, y_list = batch
assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
return {"loss": loss_val, "hiddens": self.test_hidden}
def training_epoch_end(self, training_step_outputs):
training_step_outputs = training_step_outputs[0]
assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
self.log("train_loss", loss)
def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
)
model = BpttTestModel(
batch_size=batch_size,
in_features=truncated_bptt_steps,
out_features=truncated_bptt_steps,
n_hidden_states=n_hidden_states,
)
model.example_input_array = torch.randn(5, truncated_bptt_steps)
# fit model
train_dataloader = DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False)
model = model_class(n_hidden_states=n_hidden_states, sequence_size=sequence_size, batch_size=batch_size)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=0,
weights_summary=None,
)
trainer.fit(model)
assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
trainer.fit(model, train_dataloader)
def test_tbptt_log(tmpdir):
@@ -98,7 +113,7 @@ def test_tbptt_log(tmpdir):
batch_size = 10
assert T % truncated_bptt_steps != 0, "Should test leftover time steps"
class MockSeq2SeqDataset(torch.utils.data.Dataset):
class MockSeq2SeqDataset(Dataset):
def __init__(self):
self.x_seq = torch.randn(N, T, F)
self.y_seq = torch.randn(N, T, F)
@@ -143,7 +158,7 @@ def test_tbptt_log(tmpdir):
self.test_hidden = None
def train_dataloader(self):
return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
return DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
model = TestModel()
model.training_epoch_end = None