diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5175190ce2..cfc6dec9c1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -157,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed `LightningModule.truncated_bptt_steps` to be a property ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
+
+
 - Changed `EarlyStopping` callback from by default running `EarlyStopping.on_validation_end` if only training is run. Set `check_on_train_epoch_end` to run the callback at the end of the train epoch instead of at the end of the validation epoch ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069))
 
 
@@ -205,6 +208,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
+
+
 - Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))
 
 
diff --git a/docs/source/advanced/sequences.rst b/docs/source/advanced/sequences.rst
index 759a671cc4..f010372d96 100644
--- a/docs/source/advanced/sequences.rst
+++ b/docs/source/advanced/sequences.rst
@@ -40,20 +40,31 @@ For example, it may save memory to use Truncated Backpropagation Through Time when training RNNs.
 
 Lightning can handle TBTT automatically via this flag.
 
-.. testcode::
+.. testcode:: python
 
-    # DEFAULT (single backwards pass per batch)
-    trainer = Trainer(truncated_bptt_steps=None)
+    from pytorch_lightning import LightningModule
 
-    # (split batch into sequences of size 2)
-    trainer = Trainer(truncated_bptt_steps=2)
+    class MyModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            # Important: This property activates truncated backpropagation through time
+            # Setting this value to 2 splits the batch into sequences of size 2
+            self.truncated_bptt_steps = 2
+
+        # Truncated back-propagation through time
+        def training_step(self, batch, batch_idx, hiddens):
+            # the training step must be updated to accept a ``hiddens`` argument
+            # hiddens are the hidden states from the previous truncated backprop step
+            out, hiddens = self.lstm(batch, hiddens)
+            return {
+                "loss": ...,
+                "hiddens": hiddens
+            }
 
 .. note:: If you need to modify how the batch is split, override
     :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
 
-.. note:: Using this feature requires updating your LightningModule's
-    :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
-
 ----------
 
 Iterable Datasets
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 64aed36e02..3865400121 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1005,6 +1005,63 @@ Get the model file size (in megabytes) using ``self.model_size`` inside LightningModule.
 
 --------------
 
+truncated_bptt_steps
+^^^^^^^^^^^^^^^^^^^^
+
+Truncated backpropagation through time performs backprop every k steps of
+a much longer sequence.
+
+If this is enabled, your batches will automatically get truncated
+and the Trainer will apply truncated backprop to them.
+
+(`Williams et al. "An efficient gradient-based algorithm for on-line training of
+recurrent network trajectories."
+`_)
+
+`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_
+
+.. testcode:: python
+
+    from pytorch_lightning import LightningModule
+
+    class MyModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            # Important: This property activates truncated backpropagation through time
+            # Setting this value to 2 splits the batch into sequences of size 2
+            self.truncated_bptt_steps = 2
+
+        # Truncated back-propagation through time
+        def training_step(self, batch, batch_idx, hiddens):
+            # the training step must be updated to accept a ``hiddens`` argument
+            # hiddens are the hidden states from the previous truncated backprop step
+            out, hiddens = self.lstm(batch, hiddens)
+            return {
+                "loss": ...,
+                "hiddens": hiddens
+            }
+
+Lightning takes care of splitting your batch along the time dimension.
+
+.. code-block:: python
+
+    # we use the second as the time dimension
+    # (batch, time, ...)
+    sub_batch = batch[:, 0:t, ...]
+
+To modify how the batch is split,
+override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:
+
+.. testcode:: python
+
+    class LitMNIST(LightningModule):
+        def tbptt_split_batch(self, batch, split_size):
+            # do your own splitting on the batch
+            return splits
+
+--------------
+
 Hooks
 ^^^^^
 This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index bb6981ffbd..558fbc30d5 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -196,8 +196,7 @@ class Accelerator:
                 - batch_idx (int): Integer displaying index of this batch
                 - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
                 - hiddens(:class:`~torch.Tensor`): Passed in if
-                  :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
-
+                  :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
         """
         args[0] = self.to_device(args[0])
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index ecab75ddb7..79cf6978fa 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -59,7 +59,7 @@ class LightningModule(
     Module,
 ):
     # Below is for property support of JIT in PyTorch 1.7
-    # since none of them is important when using JIT, we are going to ignore them.
+    # since none of these are important when using JIT, we are going to ignore them.
     __jit_unused_properties__ = [
         "datamodule",
         "example_input_array",
@@ -72,6 +72,8 @@ class LightningModule(
         "local_rank",
         "logger",
         "model_size",
+        "automatic_optimization",
+        "truncated_bptt_steps",
     ] + DeviceDtypeModuleMixin.__jit_unused_properties__
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -104,6 +106,7 @@ class LightningModule(
         self._current_hook_fx_name: Optional[str] = None
         self._current_dataloader_idx: Optional[int] = None
         self._automatic_optimization: bool = True
+        self._truncated_bptt_steps: int = 0
         self._param_requires_grad_state = dict()
 
     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -191,6 +194,18 @@ class LightningModule(
     def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization
 
+    @property
+    def truncated_bptt_steps(self) -> int:
+        """
+        Truncated back-propagation through time performs backprop every k steps of a much longer sequence.
+        If this is > 0, the training step is passed ``hiddens``.
+        """
+        return self._truncated_bptt_steps
+
+    @truncated_bptt_steps.setter
+    def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
+        self._truncated_bptt_steps = truncated_bptt_steps
+
     @property
     def logger(self):
         """ Reference to the logger object in the Trainer. """
@@ -524,7 +539,7 @@ class LightningModule(
             batch_idx (int): Integer displaying index of this batch
             optimizer_idx (int): When using multiple optimizers, this argument will also be present.
             hiddens(:class:`~torch.Tensor`): Passed in if
-                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+                :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
 
         Return:
             Any of.
@@ -1469,7 +1484,7 @@ class LightningModule(
         Note:
             Called in the training loop after
             :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
-            if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0.
+            if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
             Each returned batch split is passed separately to :meth:`training_step`.
 
         """
@@ -1570,7 +1585,9 @@ class LightningModule(
         if avg_training_loss is not None:
             tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
 
-        if self.trainer.truncated_bptt_steps is not None:
+        module_tbptt_enabled = self.truncated_bptt_steps > 0
+        trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
+        if module_tbptt_enabled or trainer_tbptt_enabled:
             tqdm_dict["split_idx"] = self.trainer.split_idx
 
         if self.trainer.logger is not None and self.trainer.logger.version is not None:
diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index 4c5a036c74..f27288d2b1 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -11,8 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, List, Optional, Union
+
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities.distributed import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -23,12 +26,12 @@ class TrainingTricksConnector:
 
     def on_trainer_init(
         self,
-        gradient_clip_val,
-        gradient_clip_algorithm,
-        track_grad_norm,
-        accumulate_grad_batches,
-        truncated_bptt_steps,
-        terminate_on_nan,
+        gradient_clip_val: float,
+        gradient_clip_algorithm: str,
+        track_grad_norm: Union[int, float, str],
+        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+        truncated_bptt_steps: Optional[int],
+        terminate_on_nan: bool,
     ):
         self.trainer.terminate_on_nan = terminate_on_nan
 
@@ -48,6 +51,11 @@ class TrainingTricksConnector:
         self.trainer.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
+        if truncated_bptt_steps is not None and truncated_bptt_steps > 0:
+            rank_zero_deprecation(
+                "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5."
+                " Set truncated_bptt_steps directly on the LightningModule instead."
+            )
         self.trainer.truncated_bptt_steps = truncated_bptt_steps
 
     def configure_accumulated_gradients(self, accumulate_grad_batches):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 39085ff6ef..105f8e6810 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -280,8 +280,8 @@ class Trainer(
 
             track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
 
-            truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer
-                sequence.
+            truncated_bptt_steps: Deprecated in v1.3 and will be removed in v1.5.
+                Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead.
 
             val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                 use int to check every n steps (batches).
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 8c510f08a8..a23c8ba28c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -14,7 +14,7 @@
 
 from contextlib import contextmanager, suppress
 from copy import copy, deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -441,12 +441,13 @@ class TrainLoop:
             grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm)
         return grad_norm_dict
 
-    def tbptt_split_batch(self, batch):
+    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         splits = [batch]
-        if self.trainer.truncated_bptt_steps is not None:
+        truncated_bptt_enabled = self._truncated_bptt_enabled()
+        if truncated_bptt_enabled:
             model_ref = self.trainer.lightning_module
             with self.trainer.profiler.profile("tbptt_split_batch"):
-                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
+                splits = model_ref.tbptt_split_batch(batch, self._truncated_bptt_steps())
         return splits
 
     def run_training_epoch(self):
@@ -626,7 +627,7 @@ class TrainLoop:
             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
 
         # lightning module hook
-        splits = self.tbptt_split_batch(batch)
+        splits = self._tbptt_split_batch(batch)
 
         for split_idx, split_batch in enumerate(splits):
 
@@ -896,11 +897,22 @@ class TrainLoop:
             )
 
         # pass hiddens if using tbptt
-        if self.trainer.truncated_bptt_steps is not None:
+        if self._truncated_bptt_enabled():
             args.append(hiddens)
 
         return args
 
+    def _truncated_bptt_enabled(self) -> bool:
+        """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """
""" + return self._truncated_bptt_steps() > 0 + + def _truncated_bptt_steps(self) -> int: + lightning_module = self.trainer.lightning_module + # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 + if lightning_module.truncated_bptt_steps > 0: + return lightning_module.truncated_bptt_steps + return self.trainer.truncated_bptt_steps or 0 + def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index ec49e859ba..0a8f1c7955 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -412,3 +412,8 @@ def test_v1_5_0_datamodule_setter(): model.datamodule = datamodule with pytest.deprecated_call(match="The `LightningModule.datamodule`"): _ = model.datamodule + + +def test_v1_5_0_trainer_tbptt_steps(tmpdir): + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + _ = Trainer(truncated_bptt_steps=1)