From ca679cd78f7920cc2080702d299edb33de1bd634 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 8 Sep 2021 02:26:39 +0200
Subject: [PATCH] Add `ManualOptimization` loop (#9266)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
Co-authored-by: thomas chaton
---
 CHANGELOG.md                                  |   1 +
 pyproject.toml                                |   1 +
 pytorch_lightning/loops/__init__.py           |   1 +
 pytorch_lightning/loops/batch/__init__.py     |   1 +
 pytorch_lightning/loops/batch/manual.py       |  89 +++++++++++++
 .../loops/batch/training_batch_loop.py        | 111 ++--------------
 tests/loops/test_loop_state_dict.py           |   3 +-
 tests/loops/test_loops.py                     |   1 +
 tests/models/test_truncated_bptt.py           | 123 ++++++++++--------
 9 files changed, 179 insertions(+), 152 deletions(-)
 create mode 100644 pytorch_lightning/loops/batch/manual.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b39e4dab4a..37f7b390ba 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -69,6 +69,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
   * Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
   * Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
+  * Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
 
 - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
 
diff --git a/pyproject.toml b/pyproject.toml
index eca6d81e06..11f924d51b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,6 +63,7 @@ ignore_errors = "True"
 module = [
     "pytorch_lightning.callbacks.pruning",
     "pytorch_lightning.loops.closure",
+    "pytorch_lightning.loops.batch.manual",
    "pytorch_lightning.loops.optimizer",
     "pytorch_lightning.trainer.evaluation_loop",
     "pytorch_lightning.trainer.connectors.logger_connector.*",
diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py
index 3886a21c65..dd0878b3ea 100644
--- a/pytorch_lightning/loops/__init__.py
+++ b/pytorch_lightning/loops/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from pytorch_lightning.loops.base import Loop  # noqa: F401
+from pytorch_lightning.loops.batch import ManualOptimization  # noqa: F401
 from pytorch_lightning.loops.batch import TrainingBatchLoop  # noqa: F401
 from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop  # noqa: F401
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop  # noqa: F401
diff --git a/pytorch_lightning/loops/batch/__init__.py b/pytorch_lightning/loops/batch/__init__.py
index 6e65221654..6a5b927ab6 100644
--- a/pytorch_lightning/loops/batch/__init__.py
+++ b/pytorch_lightning/loops/batch/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from pytorch_lightning.loops.batch.manual import ManualOptimization  # noqa: F401
 from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop  # noqa: F401
diff --git a/pytorch_lightning/loops/batch/manual.py b/pytorch_lightning/loops/batch/manual.py
new file mode 100644
index 0000000000..65f203a11c
--- /dev/null
+++ b/pytorch_lightning/loops/batch/manual.py
@@ -0,0 +1,89 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+from pytorch_lightning.loops import Loop
+from pytorch_lightning.loops.utilities import (
+    _build_training_step_kwargs,
+    _check_training_step_output,
+    _process_training_step_output,
+)
+from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
+
+
+class ManualOptimization(Loop):
+    """A special loop implementing what is known in Lightning as Manual Optimization, where the optimization
+    happens entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore
+    the user is responsible for back-propagating gradients and making calls to the optimizers.
+
+    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
+    :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`) and passes the output(s) through.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._done: bool = False
+        self._hiddens: Optional[Any] = None
+        self._output: Optional[ResultCollection] = None
+
+    @property
+    def done(self) -> bool:
+        return self._done
+
+    def reset(self) -> None:
+        self._done = False
+
+    def advance(self, batch: Any, batch_idx: int, hiddens: Optional[Any] = None) -> None:  # type: ignore[override]
+        """Performs the training step for manual optimization.
+
+        Args:
+            batch: the current tbptt split of the current batch
+            batch_idx: the index of the current batch
+            hiddens: the model's hidden state of the previous iteration
+        """
+        assert self.trainer is not None
+        lightning_module = self.trainer.lightning_module
+
+        with self.trainer.profiler.profile("model_forward"):
+
+            step_kwargs = _build_training_step_kwargs(
+                lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=hiddens
+            )
+
+            # manually capture logged metrics
+            lightning_module._current_fx_name = "training_step"
+            with self.trainer.profiler.profile("training_step"):
+                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
+                self.trainer.accelerator.post_training_step()
+
+            del step_kwargs
+
+            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
+
+            _check_training_step_output(lightning_module, training_step_output)
+
+            result_collection, hiddens = _process_training_step_output(self.trainer, training_step_output)
+
+        self._done = True
+        self._hiddens = hiddens
+        self._output = result_collection
+
+    def on_run_end(self) -> Tuple[Optional[ResultCollection], Optional[Any]]:
+        """Returns the result of this loop, i.e., the post-processed outputs from the training step, and the hidden
+        state."""
+        output = self._output
+        hiddens = self._hiddens
+        self._output, self._hiddens = None, None  # free memory
+        return output, hiddens
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index c2757e4035..942331f0e1 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from copy import deepcopy
-from functools import partial
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import numpy as np
 from deprecate import void
@@ -21,13 +20,8 @@
 from torch import Tensor
 from torch.optim import Optimizer
 
 from pytorch_lightning.loops.base import Loop
-from pytorch_lightning.loops.closure import Closure, ClosureResult
+from pytorch_lightning.loops.batch.manual import ManualOptimization
 from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
-from pytorch_lightning.loops.utilities import (
-    _build_training_step_kwargs,
-    _check_training_step_output,
-    _process_training_step_output,
-)
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -45,6 +39,7 @@ class TrainingBatchLoop(Loop):
         # the current split index when the batch gets split into chunks in truncated backprop through time
         self.split_idx: Optional[int] = None
         self.optimizer_loop = OptimizerLoop()
+        self.manual_loop = ManualOptimization()
 
         self._warning_cache: WarningCache = WarningCache()
         self._hiddens: Optional[Tensor] = None
@@ -63,8 +58,13 @@ class TrainingBatchLoop(Loop):
             self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
         return self._optimizer_freq_cumsum
 
-    def connect(self, optimizer_loop: "Loop") -> None:
-        self.optimizer_loop = optimizer_loop
+    def connect(
+        self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
+    ) -> None:
+        if optimizer_loop is not None:
+            self.optimizer_loop = optimizer_loop
+        if manual_loop is not None:
+            self.manual_loop = manual_loop
 
     def run(self, batch: Any, batch_idx: int) -> AttributeDict:
         """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
@@ -132,10 +132,10 @@ class TrainingBatchLoop(Loop):
                 for k in range(len(batch_outputs)):
                     self.batch_outputs[k].extend(batch_outputs[k])
             else:
-                # in manual optimization, there is no looping over optimizers
-                result = self._run_optimization(batch_idx, split_batch)
-                if result:
-                    self.batch_outputs[0].append(deepcopy(result.result_collection))
+                # in manual optimization, hand over execution to the ManualOptimization loop
+                output, self._hiddens = self.manual_loop.run(split_batch, batch_idx, self._hiddens)
+                if output:
+                    self.batch_outputs[0].append(deepcopy(output))
 
     def teardown(self) -> None:
         # release memory
@@ -145,89 +145,6 @@
         """Gets the number of active optimizers based on their frequency."""
         return len(self.get_active_optimizers(batch_idx))
 
-    def _run_optimization(
-        self,
-        batch_idx: int,
-        split_batch: Any,
-    ) -> Optional[ClosureResult]:
-        """Runs closure (train step + backward) together with optimization if necessary.
-
-        Args:
-            batch_idx: the index of the current batch
-            split_batch: the current tbptt split of the whole batch
-        """
-        # TODO: replace call through closure by direct call (manual optimization)
-        closure = self._make_closure(split_batch, batch_idx, self._hiddens)
-        closure()
-        result = closure.get_result()
-
-        if result:
-            # if no result, user decided to skip optimization
-            # otherwise update running loss + reset accumulated loss
-            self._update_running_loss(result.loss)
-
-        return result
-
-    def _make_closure(
-        self,
-        split_batch: Any,
-        batch_idx: int,
-        hiddens: Any,
-    ) -> Closure:
-        """Build a closure object that captures the given arguments and runs the `training_step` function and
-        optionally other functions such as `backward` and `zero_grad`."""
-        step_fn = self._make_step_fn(split_batch, batch_idx, hiddens)
-        backward_fn = None
-        zero_grad_fn = None
-
-        return Closure(
-            step_fn=step_fn,
-            backward_fn=backward_fn,
-            zero_grad_fn=zero_grad_fn,
-            profiler=self.trainer.profiler,
-        )
-
-    def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]:
-        """Build the step function that runs the `training_step` and processes its output."""
-        return partial(self._training_step, split_batch, batch_idx, hiddens)
-
-    def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
-        """Performs the training step for manual optimization.
-
-        Args:
-            split_batch: the current tbptt split of the current batch
-            batch_idx: the index of the current batch
-            hiddens: the model's hidden state of the previous iteration
-
-        Returns:
-            an AttributeDict containing the training step output.
-        """
-        # give the PL module a result for logging
-        model_ref = self.trainer.lightning_module
-
-        with self.trainer.profiler.profile("model_forward"):
-            step_kwargs = _build_training_step_kwargs(
-                model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
-            )
-
-            # manually capture logged metrics
-            model_ref._current_fx_name = "training_step"
-            with self.trainer.profiler.profile("training_step"):
-                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
-                self.trainer.accelerator.post_training_step()
-
-            del step_kwargs
-
-            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
-
-            _check_training_step_output(self.trainer.lightning_module, training_step_output)
-
-            result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
-            if result_collection is None:
-                return
-
-        return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)
-
     def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         """Splits a single batch into a list of sequence steps for tbptt.
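
Note: the new `connect()` signature above accepts an optional `manual_loop`, so a customized manual-optimization loop can be swapped in without touching the rest of `TrainingBatchLoop`. The following is a minimal sketch of how that might look from user code; it assumes the usual `trainer.fit_loop.epoch_loop.batch_loop` attribute path, and the `MyManualOptimization` subclass is hypothetical and not part of this patch:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import ManualOptimization


    class MyManualOptimization(ManualOptimization):
        """Hypothetical subclass that wraps the manual training step with extra bookkeeping."""

        def advance(self, batch, batch_idx, hiddens=None):
            # custom pre-step logic (e.g. extra profiling or logging) could go here
            super().advance(batch, batch_idx, hiddens)


    trainer = Trainer()
    # hand the custom loop to the batch loop via the new `connect` keyword argument
    trainer.fit_loop.epoch_loop.batch_loop.connect(manual_loop=MyManualOptimization())
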
diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
index 0010af32b4..579797edbb 100644
--- a/tests/loops/test_loop_state_dict.py
+++ b/tests/loops/test_loop_state_dict.py
@@ -56,8 +56,9 @@ def test_loops_state_dict_structure():
             "total": {"ready": 0, "completed": 0},
             "current": {"ready": 0, "completed": 0},
         },
-        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.state_dict": {},
+        "epoch_loop.batch_loop.manual_loop.state_dict": {},
+        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
             "optimizer": {
                 "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 365dbd5e82..56c2cae14f 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -464,6 +464,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
             "current": {"ready": be_sch_steps, "completed": be_sch_steps},
         },
         "epoch_loop.batch_loop.state_dict": ANY,
+        "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
         "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
             "optimizer_idx": stop_optimizer,
diff --git a/tests/models/test_truncated_bptt.py b/tests/models/test_truncated_bptt.py
index ab10a527e3..d7a2dfb316 100644
--- a/tests/models/test_truncated_bptt.py
+++ b/tests/models/test_truncated_bptt.py
@@ -14,82 +14,97 @@
 import pytest
 import torch
+from torch.utils.data import DataLoader, Dataset
 
 from pytorch_lightning import Trainer
 from tests.helpers import BoringModel
 
 
+class LinearModel(BoringModel):
+    """Linear model for testing TBPTT with automatic optimization."""
+
+    def __init__(self, truncated_bptt_steps=2, n_hidden_states=1, sequence_size=30, batch_size=30):
+        super().__init__()
+        self.truncated_bptt_steps = truncated_bptt_steps
+        self.n_hidden_states = n_hidden_states
+        self.sequence_size = sequence_size
+        self.batch_size = batch_size
+        self.automatic_optimization = True
+
+        self.example_input_array = torch.randn(5, truncated_bptt_steps)
+        self.layer = torch.nn.Linear(in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
+        self.test_hidden = None
+
+    def training_step(self, batch, batch_idx, hiddens):
+        assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
+        if self.n_hidden_states == 1:
+            self.test_hidden = torch.rand(1)
+        else:
+            self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
+
+        x_tensor, y_list = batch
+        assert x_tensor.shape[1] == self.truncated_bptt_steps, "tbptt split Tensor failed"
+
+        y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
+        assert y_tensor.shape[1] == self.truncated_bptt_steps, "tbptt split list failed"
+
+        pred = self(x_tensor.view(self.batch_size, self.truncated_bptt_steps))
+        loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(self.batch_size, self.truncated_bptt_steps))
+        return {"loss": loss_val, "hiddens": self.test_hidden}
+
+    def training_epoch_end(self, training_step_outputs):
+        training_step_outputs = training_step_outputs[0]
+        assert len(training_step_outputs) == (self.sequence_size / self.truncated_bptt_steps)
+        loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
+        assert loss.grad_fn is None
+        self.log("train_loss", loss)
+
+
+class ManualLinearModel(LinearModel):
+    """Linear model for testing TBPTT with manual optimization."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx, hiddens):
+        out = super().training_step(batch, batch_idx, hiddens)
+        loss, hiddens = out["loss"], out["hiddens"]
+        opt = self.optimizers()
+        opt.zero_grad()
+        self.manual_backward(loss)
+        opt.step()
+        assert loss.grad_fn is not None
+        return {"loss": loss, "hiddens": hiddens}
+
+
+@pytest.mark.parametrize("model_class", (LinearModel, ManualLinearModel))
 @pytest.mark.parametrize("n_hidden_states", (1, 2))
-def test_tbptt_cpu_model(tmpdir, n_hidden_states):
-    """Test truncated back propagation through time works."""
-    truncated_bptt_steps = 2
+def test_tbptt_cpu_model_manual(tmpdir, n_hidden_states, model_class):
+    """Test that truncated backpropagation through time works with automatic and manual optimization."""
+
     sequence_size = 30
     batch_size = 30
 
     x_seq = torch.rand(batch_size, sequence_size, 1)
     y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
 
-    class MockSeq2SeqDataset(torch.utils.data.Dataset):
+    class MockSeq2SeqDataset(Dataset):
         def __getitem__(self, i):
             return x_seq, y_seq_list
 
         def __len__(self):
             return 1
 
-    class BpttTestModel(BoringModel):
-        def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.test_hidden = None
-            self.batch_size = batch_size
-            self.layer = torch.nn.Linear(in_features, out_features)
-            self.n_hidden_states = n_hidden_states
-            self.truncated_bptt_steps = truncated_bptt_steps
-
-        def training_step(self, batch, batch_idx, hiddens):
-            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            if self.n_hidden_states == 1:
-                self.test_hidden = torch.rand(1)
-            else:
-                self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
-
-            x_tensor, y_list = batch
-            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
-
-            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
-            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
-
-            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
-            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
-            return {"loss": loss_val, "hiddens": self.test_hidden}
-
-        def training_epoch_end(self, training_step_outputs):
-            training_step_outputs = training_step_outputs[0]
-            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
-            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
-            self.log("train_loss", loss)
-
-        def train_dataloader(self):
-            return torch.utils.data.DataLoader(
-                dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
-            )
-
-    model = BpttTestModel(
-        batch_size=batch_size,
-        in_features=truncated_bptt_steps,
-        out_features=truncated_bptt_steps,
-        n_hidden_states=n_hidden_states,
-    )
-    model.example_input_array = torch.randn(5, truncated_bptt_steps)
-
-    # fit model
+    train_dataloader = DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False)
+    model = model_class(n_hidden_states=n_hidden_states, sequence_size=sequence_size, batch_size=batch_size)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         limit_val_batches=0,
         weights_summary=None,
     )
-    trainer.fit(model)
-    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
+    trainer.fit(model, train_dataloader)
 
 
 def test_tbptt_log(tmpdir):
@@ -98,7 +113,7 @@
     batch_size = 10
     assert T % truncated_bptt_steps != 0, "Should test leftover time steps"
 
-    class MockSeq2SeqDataset(torch.utils.data.Dataset):
+    class MockSeq2SeqDataset(Dataset):
         def __init__(self):
             self.x_seq = torch.randn(N, T, F)
             self.y_seq = torch.randn(N, T, F)
@@ -143,7 +158,7 @@ def test_tbptt_log(tmpdir):
             self.test_hidden = None
 
         def train_dataloader(self):
-            return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
+            return DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
 
     model = TestModel()
     model.training_epoch_end = None
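
Note: for readers unfamiliar with the pattern exercised by the tests above, here is a minimal, illustrative sketch of a LightningModule that combines manual optimization with truncated backpropagation through time. The module name, layer sizes and optimizer settings are hypothetical and not taken from this patch:

    import torch
    from pytorch_lightning import LightningModule


    class ManualTbpttModule(LightningModule):
        """Hypothetical module: manual optimization combined with truncated BPTT."""

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False  # routes training through the new ManualOptimization loop
            self.truncated_bptt_steps = 2  # each batch is split into chunks of 2 time steps along dim 1
            self.layer = torch.nn.Linear(2, 2)

        def training_step(self, batch, batch_idx, hiddens):
            # `batch` is one tbptt split, e.g. of shape (batch_size, truncated_bptt_steps, 1)
            x, y = batch
            pred = self.layer(x.view(x.size(0), -1))
            loss = torch.nn.functional.mse_loss(pred, y.view(y.size(0), -1))

            # in manual optimization the user drives backward/step; the loop only passes the output through
            opt = self.optimizers()
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()

            # returning "hiddens" carries state over to the next tbptt split
            return {"loss": loss, "hiddens": hiddens}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)
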