Deprecate `truncated_bptt_steps` flag on Trainer in favor of the same setting on the LightningModule (#7323)
* deprecate-tbptt-trainer
* Update CHANGELOG.md
* Update lightning.py
* test
* Update lightning.py
* Update training_loop.py
* Update training_loop.py
* Update lightning.py
* Update training_loop.py
* Update training_loop.py
* update docs
* Update accelerator.py
* Update accelerator.py
* more docs
* tweaks
* chlog
* comments

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 573a5a8a34
commit 98670c83a9
@@ -157,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed `LightningModule.truncated_bptt_steps` to be a property ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))

- Changed `EarlyStopping` callback from by default running `EarlyStopping.on_validation_end` if only training is run. Set `check_on_train_epoch_end` to run the callback at the end of the train epoch instead of at the end of the validation epoch ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069))
@@ -205,6 +208,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))

- Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))
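To make the deprecation entries above concrete, here is a minimal before/after migration sketch (an editorial illustration, not part of the diff; `MyModel` and the value `2` are placeholders):

```python
from pytorch_lightning import LightningModule, Trainer

# Before (deprecated in v1.3, removed in v1.5): TBPTT configured on the Trainer
trainer = Trainer(truncated_bptt_steps=2)

# After: TBPTT configured on the LightningModule
class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # split each batch into sequences of size 2 and pass `hiddens` between splits
        self.truncated_bptt_steps = 2

trainer = Trainer()  # no TBPTT flag on the Trainer anymore
```

The hunks below show exactly this move: the docs, the `LightningModule` property, and the training loop all switch to the module-level setting.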
@@ -40,20 +40,31 @@ For example, it may save memory to use Truncated Backpropagation Through Time wh

Lightning can handle TBPTT automatically via this flag.

.. testcode::
.. testcode:: python

    # DEFAULT (single backwards pass per batch)
    trainer = Trainer(truncated_bptt_steps=None)
    from pytorch_lightning import LightningModule

    # (split batch into sequences of size 2)
    trainer = Trainer(truncated_bptt_steps=2)
    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: This property activates truncated backpropagation through time
            # Setting this value to 2 splits the batch into sequences of size 2
            self.truncated_bptt_steps = 2

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # the training step must be updated to accept a ``hiddens`` argument
            # hiddens are the hiddens from the previous truncated backprop step
            out, hiddens = self.lstm(data, hiddens)
            return {
                "loss": ...,
                "hiddens": hiddens
            }

.. note:: If you need to modify how the batch is split,
   override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

.. note:: Using this feature requires updating your LightningModule's
   :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.

----------

Iterable Datasets
@@ -1005,6 +1005,63 @@ Get the model file size (in megabytes) using ``self.model_size`` inside Lightnin

--------------

truncated_bptt_steps
^^^^^^^^^^^^^^^^^^^^

Truncated backpropagation through time performs backprop every k steps of
a much longer sequence.

If this is enabled, your batches will automatically get truncated
and the trainer will apply Truncated Backprop to them.

(`Williams et al. "An efficient gradient-based algorithm for on-line training of
recurrent network trajectories."
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)

`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_

.. testcode:: python

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: This property activates truncated backpropagation through time
            # Setting this value to 2 splits the batch into sequences of size 2
            self.truncated_bptt_steps = 2

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # the training step must be updated to accept a ``hiddens`` argument
            # hiddens are the hiddens from the previous truncated backprop step
            out, hiddens = self.lstm(data, hiddens)
            return {
                "loss": ...,
                "hiddens": hiddens
            }

Lightning takes care to split your batch along the time dimension.

.. code-block:: python

    # we use the second dimension as the time dimension
    # (batch, time, ...)
    sub_batch = batch[0, 0:t, ...]

To modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:

.. testcode:: python

    class LitMNIST(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            # do your own splitting on the batch
            return splits
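As an editorial aside (not part of this diff), a concrete override might split every tensor in the batch along the time dimension; the batch layout, class name, and field names below are assumptions:

```python
from pytorch_lightning import LightningModule


class LitSeqModel(LightningModule):  # hypothetical model, shown only to illustrate the override
    def tbptt_split_batch(self, batch, split_size):
        # assumes batch = (inputs, targets), both shaped (batch, time, ...)
        inputs, targets = batch
        splits = []
        for t in range(0, inputs.size(1), split_size):
            # each split keeps the full batch but only `split_size` time steps
            splits.append((inputs[:, t:t + split_size], targets[:, t:t + split_size]))
        return splits
```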
--------------

Hooks
^^^^^

This is pseudocode describing how all the hooks are called during a call to ``.fit()``.
@@ -196,8 +196,7 @@ class Accelerator:

            - batch_idx (int): Integer displaying index of this batch
            - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
            - hiddens(:class:`~torch.Tensor`): Passed in if
              :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
              :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.

        """
        args[0] = self.to_device(args[0])
@@ -59,7 +59,7 @@ class LightningModule(

    Module,
):
    # Below is for property support of JIT in PyTorch 1.7
    # since none of them is important when using JIT, we are going to ignore them.
    # since none of these are important when using JIT, we are going to ignore them.
    __jit_unused_properties__ = [
        "datamodule",
        "example_input_array",
@@ -72,6 +72,8 @@ class LightningModule(

        "local_rank",
        "logger",
        "model_size",
        "automatic_optimization",
        "truncated_bptt_steps",
    ] + DeviceDtypeModuleMixin.__jit_unused_properties__

    def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -104,6 +106,7 @@ class LightningModule(

        self._current_hook_fx_name: Optional[str] = None
        self._current_dataloader_idx: Optional[int] = None
        self._automatic_optimization: bool = True
        self._truncated_bptt_steps: int = 0
        self._param_requires_grad_state = dict()

    def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -191,6 +194,18 @@ class LightningModule(

    def automatic_optimization(self, automatic_optimization: bool) -> None:
        self._automatic_optimization = automatic_optimization

    @property
    def truncated_bptt_steps(self) -> int:
        """
        truncated_bptt_steps: Truncated back prop performs backprop every k steps of a much longer sequence.
        If this is > 0, the training step is passed ``hiddens``.
        """
        return self._truncated_bptt_steps

    @truncated_bptt_steps.setter
    def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
        self._truncated_bptt_steps = truncated_bptt_steps

    @property
    def logger(self):
        """ Reference to the logger object in the Trainer. """
@@ -524,7 +539,7 @@ class LightningModule(

            batch_idx (int): Integer displaying index of this batch
            optimizer_idx (int): When using multiple optimizers, this argument will also be present.
            hiddens(:class:`~torch.Tensor`): Passed in if
                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
                :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.

        Return:
            Any of.
@@ -1469,7 +1484,7 @@ class LightningModule(

        Note:
            Called in the training loop after
            :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
            if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0.
            if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
            Each returned batch split is passed separately to :meth:`training_step`.

        """
@@ -1570,7 +1585,9 @@ class LightningModule(

        if avg_training_loss is not None:
            tqdm_dict["loss"] = f"{avg_training_loss:.3g}"

        if self.trainer.truncated_bptt_steps is not None:
        module_tbptt_enabled = self.truncated_bptt_steps > 0
        trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
        if module_tbptt_enabled or trainer_tbptt_enabled:
            tqdm_dict["split_idx"] = self.trainer.split_idx

        if self.trainer.logger is not None and self.trainer.logger.version is not None:
@@ -11,8 +11,11 @@

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union

from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -23,12 +26,12 @@ class TrainingTricksConnector:

    def on_trainer_init(
        self,
        gradient_clip_val,
        gradient_clip_algorithm,
        track_grad_norm,
        accumulate_grad_batches,
        truncated_bptt_steps,
        terminate_on_nan,
        gradient_clip_val: float,
        gradient_clip_algorithm: str,
        track_grad_norm: Union[int, float, str],
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
        truncated_bptt_steps: Optional[int],
        terminate_on_nan: bool,
    ):

        self.trainer.terminate_on_nan = terminate_on_nan
@@ -48,6 +51,11 @@ class TrainingTricksConnector:

        self.trainer.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        if truncated_bptt_steps is not None and truncated_bptt_steps > 0:
            rank_zero_deprecation(
                "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5."
                " Set truncated_bptt_steps directly on the LightningModule instead."
            )
        self.trainer.truncated_bptt_steps = truncated_bptt_steps

    def configure_accumulated_gradients(self, accumulate_grad_batches):
@@ -280,8 +280,8 @@ class Trainer(

        track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

        truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer
            sequence.
        truncated_bptt_steps: Deprecated in v1.3 to be removed in 1.5.
            Please use :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead.

        val_check_interval: How often to check the validation set. Use float to check within a training epoch,
            use int to check every n steps (batches).
@@ -14,7 +14,7 @@

from contextlib import contextmanager, suppress
from copy import copy, deepcopy
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
@@ -441,12 +441,13 @@ class TrainLoop:

        grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm)
        return grad_norm_dict

    def tbptt_split_batch(self, batch):
    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        splits = [batch]
        if self.trainer.truncated_bptt_steps is not None:
        truncated_bptt_enabled = self._truncated_bptt_enabled()
        if truncated_bptt_enabled:
            model_ref = self.trainer.lightning_module
            with self.trainer.profiler.profile("tbptt_split_batch"):
                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
                splits = model_ref.tbptt_split_batch(batch, self._truncated_bptt_steps())
        return splits

    def run_training_epoch(self):
@@ -626,7 +627,7 @@ class TrainLoop:

            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # lightning module hook
        splits = self.tbptt_split_batch(batch)
        splits = self._tbptt_split_batch(batch)

        for split_idx, split_batch in enumerate(splits):
@@ -896,11 +897,22 @@ class TrainLoop:

            )

        # pass hiddens if using tbptt
        if self.trainer.truncated_bptt_steps is not None:
        if self._truncated_bptt_enabled():
            args.append(hiddens)

        return args

    def _truncated_bptt_enabled(self) -> bool:
        """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """
        return self._truncated_bptt_steps() > 0

    def _truncated_bptt_steps(self) -> int:
        lightning_module = self.trainer.lightning_module
        # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5
        if lightning_module.truncated_bptt_steps > 0:
            return lightning_module.truncated_bptt_steps
        return self.trainer.truncated_bptt_steps or 0

    def save_loggers_on_train_batch_end(self):
        # when loggers should save to disk
        should_flush_logs = self.trainer.logger_connector.should_flush_logs
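The `_truncated_bptt_steps` helper above gives the LightningModule value precedence over the deprecated Trainer flag. A rough illustration of the resulting behavior (editorial sketch; `MyModel` refers to the documentation example above):

```python
from pytorch_lightning import Trainer

model = MyModel()                          # defines self.truncated_bptt_steps = 2
trainer = Trainer(truncated_bptt_steps=4)  # deprecated flag, still accepted in v1.3

# The training loop asks the module first, so batches are split into
# sequences of 2; the Trainer value (4) would only be used if the
# module attribute were 0 (its default).
```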
@@ -412,3 +412,8 @@ def test_v1_5_0_datamodule_setter():

    model.datamodule = datamodule
    with pytest.deprecated_call(match="The `LightningModule.datamodule`"):
        _ = model.datamodule


def test_v1_5_0_trainer_tbptt_steps(tmpdir):
    with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
        _ = Trainer(truncated_bptt_steps=1)
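As a sketch of a companion check (not part of this diff; based only on the property default and setter added above), one could assert that the new LightningModule attribute round-trips:

```python
from pytorch_lightning import LightningModule


def test_lightning_module_truncated_bptt_property():
    # hypothetical test name; exercises the property/setter introduced in this PR
    model = LightningModule()
    assert model.truncated_bptt_steps == 0  # default set in __init__
    model.truncated_bptt_steps = 2
    assert model.truncated_bptt_steps == 2
```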