Add `ManualOptimization` loop (#9266)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>

commit ca679cd78f
parent a79c351a6a
@@ -69,6 +69,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
* Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
* Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
* Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))

- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
@@ -63,6 +63,7 @@ ignore_errors = "True"
module = [
    "pytorch_lightning.callbacks.pruning",
    "pytorch_lightning.loops.closure",
    "pytorch_lightning.loops.batch.manual",
    "pytorch_lightning.loops.optimizer",
    "pytorch_lightning.trainer.evaluation_loop",
    "pytorch_lightning.trainer.connectors.logger_connector.*",
@@ -13,6 +13,7 @@
# limitations under the License.

from pytorch_lightning.loops.base import Loop  # noqa: F401
from pytorch_lightning.loops.batch import ManualOptimization  # noqa: F401
from pytorch_lightning.loops.batch import TrainingBatchLoop  # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop  # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop  # noqa: F401
@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.batch.manual import ManualOptimization  # noqa: F401
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop  # noqa: F401
@@ -0,0 +1,89 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Tuple

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.utilities import (
    _build_training_step_kwargs,
    _check_training_step_output,
    _process_training_step_output,
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection


class ManualOptimization(Loop):
    """A special loop implementing what is known in Lightning as Manual Optimization, where the optimization happens
    entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore the user
    is responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
    :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`) and passes through the output(s).
    """

    def __init__(self) -> None:
        super().__init__()
        self._done: bool = False
        self._hiddens: Optional[Any] = None
        self._output: Optional[ResultCollection] = None

    @property
    def done(self) -> bool:
        return self._done

    def reset(self) -> None:
        self._done = False

    def advance(self, batch: Any, batch_idx: int, hiddens: Optional[Any] = None) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            hiddens: the model's hidden state of the previous iteration
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(
                lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=hiddens
            )

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            result_collection, hiddens = _process_training_step_output(self.trainer, training_step_output)

        self._done = True
        self._hiddens = hiddens
        self._output = result_collection

    def on_run_end(self) -> Tuple[Optional[ResultCollection], Optional[Any]]:
        """Returns the result of this loop, i.e., the post-processed outputs from the training step, and the hidden
        state."""
        output = self._output
        hiddens = self._hiddens
        self._output, self._hiddens = None, None  # free memory
        return output, hiddens
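For context, here is a minimal sketch (not part of this commit) of the kind of `LightningModule` this loop serves: with `automatic_optimization = False`, the `training_step` itself drives `zero_grad`, `manual_backward`, and `step`, mirroring the `ManualLinearModel` used in the tests further down. The module name, layer sizes, and optimizer choice are illustrative only.

```python
import torch
from pytorch_lightning import LightningModule


class SketchManualModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # opt out of automatic optimization; the new ManualOptimization loop runs this step
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # Lightning's optimizer wrapper
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)  # user-driven backward instead of Lightning's
        opt.step()
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```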
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from functools import partial
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import numpy as np
from deprecate import void
@@ -21,13 +20,8 @@ from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.batch.manual import ManualOptimization
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import (
    _build_training_step_kwargs,
    _check_training_step_output,
    _process_training_step_output,
)
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -45,6 +39,7 @@ class TrainingBatchLoop(Loop):
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        self.optimizer_loop = OptimizerLoop()
        self.manual_loop = ManualOptimization()

        self._warning_cache: WarningCache = WarningCache()
        self._hiddens: Optional[Tensor] = None
@@ -63,8 +58,13 @@ class TrainingBatchLoop(Loop):
            self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
        return self._optimizer_freq_cumsum

    def connect(self, optimizer_loop: "Loop") -> None:
        self.optimizer_loop = optimizer_loop
    def connect(
        self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
    ) -> None:
        if optimizer_loop is not None:
            self.optimizer_loop = optimizer_loop
        if manual_loop is not None:
            self.manual_loop = manual_loop

    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
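Because `connect` now accepts either sub-loop independently, a custom manual-optimization loop can be swapped in without touching the optimizer loop. A hedged sketch of what that might look like from user code; the subclass, its print statement, and the `trainer.fit_loop.epoch_loop.batch_loop` attribute path are illustrative assumptions rather than part of this commit:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loops.batch import ManualOptimization


class VerboseManualOptimization(ManualOptimization):
    """Hypothetical subclass that logs each manual training step."""

    def advance(self, batch, batch_idx, hiddens=None):
        print(f"manual training_step for batch {batch_idx}")
        super().advance(batch, batch_idx, hiddens)


trainer = Trainer()
# swap only the manual sub-loop; optimizer_loop is left as None and therefore untouched
trainer.fit_loop.epoch_loop.batch_loop.connect(manual_loop=VerboseManualOptimization())
```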
@@ -132,10 +132,10 @@ class TrainingBatchLoop(Loop):
            for k in range(len(batch_outputs)):
                self.batch_outputs[k].extend(batch_outputs[k])
        else:
            # in manual optimization, there is no looping over optimizers
            result = self._run_optimization(batch_idx, split_batch)
            if result:
                self.batch_outputs[0].append(deepcopy(result.result_collection))
            # in manual optimization, hand over execution to the ManualOptimization loop
            output, self._hiddens = self.manual_loop.run(split_batch, batch_idx, self._hiddens)
            if output:
                self.batch_outputs[0].append(deepcopy(output))

    def teardown(self) -> None:
        # release memory
@@ -145,89 +145,6 @@ class TrainingBatchLoop(Loop):
        """Gets the number of active optimizers based on their frequency."""
        return len(self.get_active_optimizers(batch_idx))

    def _run_optimization(
        self,
        batch_idx: int,
        split_batch: Any,
    ) -> Optional[ClosureResult]:
        """Runs closure (train step + backward) together with optimization if necessary.

        Args:
            batch_idx: the index of the current batch
            split_batch: the current tbptt split of the whole batch
        """
        # TODO: replace call through closure by direct call (manual optimization)
        closure = self._make_closure(split_batch, batch_idx, self._hiddens)
        closure()
        result = closure.get_result()

        if result:
            # if no result, user decided to skip optimization
            # otherwise update running loss + reset accumulated loss
            self._update_running_loss(result.loss)

        return result

    def _make_closure(
        self,
        split_batch: Any,
        batch_idx: int,
        hiddens: Any,
    ) -> Closure:
        """Build a closure object that captures the given arguments and runs the `training_step` function and
        optionally other functions such as `backward` and `zero_grad`."""
        step_fn = self._make_step_fn(split_batch, batch_idx, hiddens)
        backward_fn = None
        zero_grad_fn = None

        return Closure(
            step_fn=step_fn,
            backward_fn=backward_fn,
            zero_grad_fn=zero_grad_fn,
            profiler=self.trainer.profiler,
        )

    def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]:
        """Build the step function that runs the `training_step` and processes its output."""
        return partial(self._training_step, split_batch, batch_idx, hiddens)

    def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
        """Performs the training step for manual optimization.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            hiddens: the model's hidden state of the previous iteration

        Returns:
            an AttributeDict containing the training step output.
        """
        # give the PL module a result for logging
        model_ref = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):
            step_kwargs = _build_training_step_kwargs(
                model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
            )

            # manually capture logged metrics
            model_ref._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

            _check_training_step_output(self.trainer.lightning_module, training_step_output)

            result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
            if result_collection is None:
                return

        return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)

    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        """Splits a single batch into a list of sequence steps for tbptt.
@@ -56,8 +56,9 @@ def test_loops_state_dict_structure():
            "total": {"ready": 0, "completed": 0},
            "current": {"ready": 0, "completed": 0},
        },
        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
        "epoch_loop.batch_loop.state_dict": {},
        "epoch_loop.batch_loop.manual_loop.state_dict": {},
        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
        "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
            "optimizer": {
                "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
@@ -464,6 +464,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
            "current": {"ready": be_sch_steps, "completed": be_sch_steps},
        },
        "epoch_loop.batch_loop.state_dict": ANY,
        "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
        "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
            "optimizer_idx": stop_optimizer,
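The expected dictionaries in these two tests show where the new `epoch_loop.batch_loop.manual_loop.state_dict` key comes from: each loop serializes its child loops under dotted prefixes. A small hedged sketch of inspecting that nesting on a fresh `Trainer`; the exact `state_dict()` call path is an assumption inferred from the keys asserted above, not something this commit adds:

```python
from pytorch_lightning import Trainer

trainer = Trainer()
batch_loop = trainer.fit_loop.epoch_loop.batch_loop
# the batch loop now owns both sub-loops split out in this series of changes
assert hasattr(batch_loop, "manual_loop") and hasattr(batch_loop, "optimizer_loop")

# child loop state is nested under dotted prefixes when the parent serializes itself
state = trainer.fit_loop.state_dict()
assert "epoch_loop.batch_loop.manual_loop.state_dict" in state
```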
@@ -14,82 +14,97 @@

import pytest
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Trainer
from tests.helpers import BoringModel


class LinearModel(BoringModel):
    """Linear model for testing TBPTT with automatic optimization."""

    def __init__(self, truncated_bptt_steps=2, n_hidden_states=1, sequence_size=30, batch_size=30):
        super().__init__()
        self.truncated_bptt_steps = truncated_bptt_steps
        self.n_hidden_states = n_hidden_states
        self.sequence_size = sequence_size
        self.batch_size = batch_size
        self.automatic_optimization = True

        self.example_input_array = torch.randn(5, truncated_bptt_steps)
        self.layer = torch.nn.Linear(in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
        self.test_hidden = None

    def training_step(self, batch, batch_idx, hiddens):
        assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
        if self.n_hidden_states == 1:
            self.test_hidden = torch.rand(1)
        else:
            self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)

        x_tensor, y_list = batch
        assert x_tensor.shape[1] == self.truncated_bptt_steps, "tbptt split Tensor failed"

        y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
        assert y_tensor.shape[1] == self.truncated_bptt_steps, "tbptt split list failed"

        pred = self(x_tensor.view(self.batch_size, self.truncated_bptt_steps))
        loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(self.batch_size, self.truncated_bptt_steps))
        return {"loss": loss_val, "hiddens": self.test_hidden}

    def training_epoch_end(self, training_step_outputs):
        training_step_outputs = training_step_outputs[0]
        assert len(training_step_outputs) == (self.sequence_size / self.truncated_bptt_steps)
        loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
        assert loss.grad_fn is None
        self.log("train_loss", loss)


class ManualLinearModel(LinearModel):
    """Linear model for testing TBPTT with manual optimization."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx, hiddens):
        out = super().training_step(batch, batch_idx, hiddens)
        loss, hiddens = out["loss"], out["hiddens"]
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        assert loss.grad_fn is not None
        return {"loss": loss, "hiddens": hiddens}


@pytest.mark.parametrize("model_class", (LinearModel, ManualLinearModel))
@pytest.mark.parametrize("n_hidden_states", (1, 2))
def test_tbptt_cpu_model(tmpdir, n_hidden_states):
    """Test truncated back propagation through time works."""
    truncated_bptt_steps = 2
def test_tbptt_cpu_model_manual(tmpdir, n_hidden_states, model_class):
    """Test truncated back propagation through time works with automatic and manual optimization."""

    sequence_size = 30
    batch_size = 30

    x_seq = torch.rand(batch_size, sequence_size, 1)
    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()

    class MockSeq2SeqDataset(torch.utils.data.Dataset):
    class MockSeq2SeqDataset(Dataset):
        def __getitem__(self, i):
            return x_seq, y_seq_list

        def __len__(self):
            return 1

    class BpttTestModel(BoringModel):
        def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.test_hidden = None
            self.batch_size = batch_size
            self.layer = torch.nn.Linear(in_features, out_features)
            self.n_hidden_states = n_hidden_states
            self.truncated_bptt_steps = truncated_bptt_steps

        def training_step(self, batch, batch_idx, hiddens):
            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
            if self.n_hidden_states == 1:
                self.test_hidden = torch.rand(1)
            else:
                self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)

            x_tensor, y_list = batch
            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"

            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"

            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
            return {"loss": loss_val, "hiddens": self.test_hidden}

        def training_epoch_end(self, training_step_outputs):
            training_step_outputs = training_step_outputs[0]
            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
            self.log("train_loss", loss)

        def train_dataloader(self):
            return torch.utils.data.DataLoader(
                dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
            )

    model = BpttTestModel(
        batch_size=batch_size,
        in_features=truncated_bptt_steps,
        out_features=truncated_bptt_steps,
        n_hidden_states=n_hidden_states,
    )
    model.example_input_array = torch.randn(5, truncated_bptt_steps)

    # fit model
    train_dataloader = DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False)
    model = model_class(n_hidden_states=n_hidden_states, sequence_size=sequence_size, batch_size=batch_size)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0,
        weights_summary=None,
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
    trainer.fit(model, train_dataloader)


def test_tbptt_log(tmpdir):
@@ -98,7 +113,7 @@ def test_tbptt_log(tmpdir):
    batch_size = 10
    assert T % truncated_bptt_steps != 0, "Should test leftover time steps"

    class MockSeq2SeqDataset(torch.utils.data.Dataset):
    class MockSeq2SeqDataset(Dataset):
        def __init__(self):
            self.x_seq = torch.randn(N, T, F)
            self.y_seq = torch.randn(N, T, F)
@@ -143,7 +158,7 @@ def test_tbptt_log(tmpdir):
            self.test_hidden = None

        def train_dataloader(self):
            return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
            return DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)

    model = TestModel()
    model.training_epoch_end = None