Deprecate `truncated_bptt_steps` flag on Trainer in favor of the same setting on the LightningModule (#7323)

* deprecate-tbptt-trainer

* Update CHANGELOG.md

* Update lightning.py

* test

* Update lightning.py

* Update training_loop.py

* Update training_loop.py

* Update lightning.py

* Update training_loop.py

* Update training_loop.py

* update docs

* Update accelerator.py

* Update accelerator.py

* more docs

* tweaks

* chlog

* comments

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
ananthsub 2021-05-05 03:21:00 -07:00 committed by GitHub
parent 573a5a8a34
commit 98670c83a9
9 changed files with 143 additions and 28 deletions
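
In user code, the migration this deprecation asks for looks roughly like the sketch below. The example is illustrative only and not part of the commit; `MyModel`, the toy LSTM, and the MSE loss are placeholder choices.

import torch
from pytorch_lightning import LightningModule, Trainer

# Before (deprecated in v1.3, to be removed in v1.5): configure TBPTT on the Trainer
# trainer = Trainer(truncated_bptt_steps=2)

# After: configure TBPTT on the LightningModule itself
class MyModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=32, hidden_size=32, batch_first=True)
        # splits each batch into sequences of length 2 along the time dimension
        self.truncated_bptt_steps = 2

    def training_step(self, batch, batch_idx, hiddens):
        # ``hiddens`` carries the hidden state over from the previous truncated split
        x, y = batch
        out, hiddens = self.lstm(x, hiddens)
        loss = torch.nn.functional.mse_loss(out, y)
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

trainer = Trainer()  # the truncated_bptt_steps argument is no longer needed here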


@@ -157,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed
- Changed `LightningModule.truncated_bptt_steps` to be a property ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
- Changed the `EarlyStopping` callback from running `EarlyStopping.on_validation_end` by default if only training is run; set `check_on_train_epoch_end` to run the callback at the end of the train epoch instead of at the end of the validation epoch ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069))
@@ -205,6 +208,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
- Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))


@@ -40,20 +40,31 @@ For example, it may save memory to use Truncated Backpropagation Through Time wh
Lightning can handle TBTT automatically via this flag.

-.. testcode::
+.. testcode:: python

-    # DEFAULT (single backwards pass per batch)
-    trainer = Trainer(truncated_bptt_steps=None)
+    from pytorch_lightning import LightningModule

-    # (split batch into sequences of size 2)
-    trainer = Trainer(truncated_bptt_steps=2)
+    class MyModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            # Important: This property activates truncated backpropagation through time
+            # Setting this value to 2 splits the batch into sequences of size 2
+            self.truncated_bptt_steps = 2
+
+        # Truncated back-propagation through time
+        def training_step(self, batch, batch_idx, hiddens):
+            # the training step must be updated to accept a ``hiddens`` argument
+            # hiddens are the hiddens from the previous truncated backprop step
+            out, hiddens = self.lstm(data, hiddens)
+            return {
+                "loss": ...,
+                "hiddens": hiddens
+            }

.. note:: If you need to modify how the batch is split,
   override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

+.. note:: Using this feature requires updating your LightningModule's
+   :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
----------
Iterable Datasets


@@ -1005,6 +1005,63 @@ Get the model file size (in megabytes) using ``self.model_size`` inside Lightnin
--------------

truncated_bptt_steps
^^^^^^^^^^^^^^^^^^^^

Truncated backpropagation through time (TBPTT) performs backprop every k steps of
a much longer sequence.

If this is enabled, your batches are automatically split into smaller sequences
and the Trainer applies truncated backprop to each of them.

(`Williams et al. "An efficient gradient-based algorithm for on-line training of
recurrent network trajectories."
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)

`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_

.. testcode:: python

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: This property activates truncated backpropagation through time
            # Setting this value to 2 splits the batch into sequences of size 2
            self.truncated_bptt_steps = 2

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # the training step must be updated to accept a ``hiddens`` argument
            # hiddens are the hiddens from the previous truncated backprop step
            out, hiddens = self.lstm(data, hiddens)
            return {
                "loss": ...,
                "hiddens": hiddens
            }

Lightning takes care of splitting your batch along the time dimension.

.. code-block:: python

    # we use the second dimension as the time dimension
    # (batch, time, ...)
    sub_batch = batch[0, 0:t, ...]

To modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:

.. testcode:: python

    class LitMNIST(LightningModule):

        def tbptt_split_batch(self, batch, split_size):
            # do your own splitting on the batch
            return splits
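
The stub above leaves the actual splitting unspecified. A concrete override might look like the following sketch, assuming the batch is an ``(x, y)`` pair of tensors shaped ``(batch, time, ...)``; the class name ``LitSequenceModel`` is illustrative and not part of this commit.

from pytorch_lightning import LightningModule

class LitSequenceModel(LightningModule):

    def tbptt_split_batch(self, batch, split_size):
        # split both tensors into chunks of ``split_size`` along the time dimension (dim=1)
        x, y = batch
        splits = [
            (x[:, t:t + split_size], y[:, t:t + split_size])
            for t in range(0, x.size(1), split_size)
        ]
        return splits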
--------------
Hooks
^^^^^
This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.


@@ -196,8 +196,7 @@ class Accelerator:
                - batch_idx (int): Integer displaying index of this batch
                - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
                - hiddens(:class:`~torch.Tensor`): Passed in if
-                  :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+                  :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
        """
        args[0] = self.to_device(args[0])


@@ -59,7 +59,7 @@ class LightningModule(
    Module,
):
    # Below is for property support of JIT in PyTorch 1.7
-    # since none of them is important when using JIT, we are going to ignore them.
+    # since none of these are important when using JIT, we are going to ignore them.
    __jit_unused_properties__ = [
        "datamodule",
        "example_input_array",
@@ -72,6 +72,8 @@ class LightningModule(
        "local_rank",
        "logger",
        "model_size",
        "automatic_optimization",
        "truncated_bptt_steps",
    ] + DeviceDtypeModuleMixin.__jit_unused_properties__

    def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -104,6 +106,7 @@ class LightningModule(
        self._current_hook_fx_name: Optional[str] = None
        self._current_dataloader_idx: Optional[int] = None
        self._automatic_optimization: bool = True
        self._truncated_bptt_steps: int = 0
        self._param_requires_grad_state = dict()

    def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -191,6 +194,18 @@ class LightningModule(
    def automatic_optimization(self, automatic_optimization: bool) -> None:
        self._automatic_optimization = automatic_optimization

    @property
    def truncated_bptt_steps(self) -> int:
        """
        Truncated backpropagation through time (TBPTT) performs backprop every k steps of a much longer sequence.
        If this is > 0, the training step is passed ``hiddens``.
        """
        return self._truncated_bptt_steps

    @truncated_bptt_steps.setter
    def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
        self._truncated_bptt_steps = truncated_bptt_steps

    @property
    def logger(self):
        """ Reference to the logger object in the Trainer. """
@@ -524,7 +539,7 @@ class LightningModule(
            batch_idx (int): Integer displaying index of this batch
            optimizer_idx (int): When using multiple optimizers, this argument will also be present.
            hiddens(:class:`~torch.Tensor`): Passed in if
-                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+                :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.

        Return:
            Any of.
@@ -1469,7 +1484,7 @@ class LightningModule(
        Note:
            Called in the training loop after
            :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
-            if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0.
+            if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
            Each returned batch split is passed separately to :meth:`training_step`.

        """
@@ -1570,7 +1585,9 @@ class LightningModule(
        if avg_training_loss is not None:
            tqdm_dict["loss"] = f"{avg_training_loss:.3g}"

-        if self.trainer.truncated_bptt_steps is not None:
+        module_tbptt_enabled = self.truncated_bptt_steps > 0
+        trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
+        if module_tbptt_enabled or trainer_tbptt_enabled:
            tqdm_dict["split_idx"] = self.trainer.split_idx

        if self.trainer.logger is not None and self.trainer.logger.version is not None:


@@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -23,12 +26,12 @@ class TrainingTricksConnector:
    def on_trainer_init(
        self,
-        gradient_clip_val,
-        gradient_clip_algorithm,
-        track_grad_norm,
-        accumulate_grad_batches,
-        truncated_bptt_steps,
-        terminate_on_nan,
+        gradient_clip_val: float,
+        gradient_clip_algorithm: str,
+        track_grad_norm: Union[int, float, str],
+        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+        truncated_bptt_steps: Optional[int],
+        terminate_on_nan: bool,
    ):
        self.trainer.terminate_on_nan = terminate_on_nan
@@ -48,6 +51,11 @@ class TrainingTricksConnector:
        self.trainer.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        if truncated_bptt_steps is not None and truncated_bptt_steps > 0:
            rank_zero_deprecation(
                "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5."
                " Set truncated_bptt_steps directly on the LightningModule instead."
            )
        self.trainer.truncated_bptt_steps = truncated_bptt_steps

    def configure_accumulated_gradients(self, accumulate_grad_batches):


@@ -280,8 +280,8 @@ class Trainer(
            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

-            truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer
-                sequence.
+            truncated_bptt_steps: Deprecated in v1.3 and will be removed in v1.5.
+                Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead.

            val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                use int to check every n steps (batches).


@@ -14,7 +14,7 @@
from contextlib import contextmanager, suppress
from copy import copy, deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
@@ -441,12 +441,13 @@ class TrainLoop:
            grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm)
        return grad_norm_dict

-    def tbptt_split_batch(self, batch):
+    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        splits = [batch]
-        if self.trainer.truncated_bptt_steps is not None:
+        truncated_bptt_enabled = self._truncated_bptt_enabled()
+        if truncated_bptt_enabled:
            model_ref = self.trainer.lightning_module
            with self.trainer.profiler.profile("tbptt_split_batch"):
-                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
+                splits = model_ref.tbptt_split_batch(batch, self._truncated_bptt_steps())
        return splits

    def run_training_epoch(self):
@@ -626,7 +627,7 @@ class TrainLoop:
            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # lightning module hook
-        splits = self.tbptt_split_batch(batch)
+        splits = self._tbptt_split_batch(batch)

        for split_idx, split_batch in enumerate(splits):
@@ -896,11 +897,22 @@ class TrainLoop:
        )

        # pass hiddens if using tbptt
-        if self.trainer.truncated_bptt_steps is not None:
+        if self._truncated_bptt_enabled():
            args.append(hiddens)

        return args

    def _truncated_bptt_enabled(self) -> bool:
        """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """
        return self._truncated_bptt_steps() > 0

    def _truncated_bptt_steps(self) -> int:
        lightning_module = self.trainer.lightning_module
        # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5
        if lightning_module.truncated_bptt_steps > 0:
            return lightning_module.truncated_bptt_steps
        return self.trainer.truncated_bptt_steps or 0

    def save_loggers_on_train_batch_end(self):
        # when loggers should save to disk
        should_flush_logs = self.trainer.logger_connector.should_flush_logs
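
For readability, here is a condensed, standalone restatement of the precedence rule implemented by `_truncated_bptt_steps` above (a sketch, not library code): the LightningModule property wins, and the deprecated Trainer flag is only a fallback.

from typing import Optional

def effective_tbptt_steps(module_steps: int, trainer_steps: Optional[int]) -> int:
    # mirrors TrainLoop._truncated_bptt_steps: the LightningModule property takes precedence
    if module_steps > 0:
        return module_steps
    return trainer_steps or 0

assert effective_tbptt_steps(module_steps=4, trainer_steps=2) == 4      # module wins
assert effective_tbptt_steps(module_steps=0, trainer_steps=2) == 2      # deprecated Trainer flag as fallback
assert effective_tbptt_steps(module_steps=0, trainer_steps=None) == 0   # TBPTT disabled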


@@ -412,3 +412,8 @@ def test_v1_5_0_datamodule_setter():
    model.datamodule = datamodule
    with pytest.deprecated_call(match="The `LightningModule.datamodule`"):
        _ = model.datamodule


def test_v1_5_0_trainer_tbptt_steps(tmpdir):
    with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
        _ = Trainer(truncated_bptt_steps=1)
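
A companion check one might add alongside this test (hypothetical, not part of the commit) to confirm the supported LightningModule path and its default value:

def test_lightning_module_tbptt_property(tmpdir):
    # hypothetical companion test: the property on the LightningModule is the supported path
    from pytorch_lightning import LightningModule

    class TbpttModule(LightningModule):
        pass

    model = TbpttModule()
    assert model.truncated_bptt_steps == 0  # TBPTT is disabled by default
    model.truncated_bptt_steps = 2
    assert model.truncated_bptt_steps == 2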