Sequential Data
===============
Truncated Backpropagation Through Time
--------------------------------------
There are times when multiple backwards passes are needed for each batch.
For example, it may save memory to use Truncated Backpropagation Through Time when training RNNs.

Lightning can handle TBTT automatically via this flag.
.. testcode::

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            # Important: This property activates truncated backpropagation through time
            # Setting this value to 2 splits the batch into sequences of size 2
            self.truncated_bptt_steps = 2

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # the training step must be updated to accept a ``hiddens`` argument
            # hiddens are the hiddens from the previous truncated backprop step
            out, hiddens = self.lstm(data, hiddens)
            return {"loss": ..., "hiddens": hiddens}

.. note:: If you need to modify how the batch is split,
    override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.