ref: inner train loop (intermediate step) 16/n (#3375)
* ref: inner train loop (intermediate step) 16/n
* ref: inner train loop (intermediate step) 16/n
* ref: inner train loop (intermediate step) 16/n
* ref: inner train loop (intermediate step) 16/n
* ref: inner train loop (intermediate step) 16/n
* ref: inner train loop (intermediate step) 16/n
This commit is contained in:
parent bce5c81f3a
commit 69e3f904df
@ -12,156 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The Lightning training loop handles everything except the actual computations of your model.
To decide what will happen in your training loop, define the `training_step` function.

Below are all the things Lightning automates for you in the training loop.

Accumulated gradients
---------------------

Gradient accumulation runs K small batches of size N before doing a backward pass.
The result is an effective batch size of K x N.

.. code-block:: python

    # DEFAULT (ie: no accumulated grads)
    trainer = Trainer(accumulate_grad_batches=1)
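
For instance, with a DataLoader batch size of 32, accumulating over 4 batches behaves like training with a batch size of 128 (the values here are purely illustrative):

.. code-block:: python

    # accumulate gradients over 4 batches before each optimizer step;
    # with a loader batch size of 32 this gives an effective batch size of 128
    trainer = Trainer(accumulate_grad_batches=4)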

Force training for min or max epochs
------------------------------------

It can be useful to force training for a minimum number of epochs or to limit it to a maximum number of epochs.

.. code-block:: python

    # DEFAULT
    trainer = Trainer(min_epochs=1, max_epochs=1000)
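
For example, to guarantee at least 5 epochs but stop after 20 at the latest (the values are chosen only for illustration):

.. code-block:: python

    # train for no fewer than 5 and no more than 20 epochs
    trainer = Trainer(min_epochs=5, max_epochs=20)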

Force disable early stop
------------------------

To disable early stopping, pass None to the `early_stop_callback` argument.

.. code-block:: python

    # DEFAULT
    trainer = Trainer(early_stop_callback=None)
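
Conversely, early stopping can be enabled by passing an `EarlyStopping` callback (a minimal sketch; the monitored key and patience are illustrative and must match a metric you actually log):

.. code-block:: python

    from pytorch_lightning.callbacks import EarlyStopping

    # stop training once `val_loss` has not improved for 3 consecutive checks
    early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min')
    trainer = Trainer(early_stop_callback=early_stop)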

Gradient Clipping
-----------------

Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will
`clip the gradient norm computed over all model parameters together
<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_.

.. code-block:: python

    # DEFAULT (ie: don't clip)
    trainer = Trainer(gradient_clip_val=0)

    # clip gradients so that their norm is at most 0.5
    trainer = Trainer(gradient_clip_val=0.5)
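
Under the hood this corresponds roughly to the following PyTorch call before each optimizer step (a simplified sketch, not Lightning's exact implementation; `model` stands for your module):

.. code-block:: python

    import torch

    # rescale all gradients in place so that their combined norm is at most 0.5
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)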

Inspect gradient norms
----------------------

Looking at gradient norms can help you figure out where training might be going wrong.

.. code-block:: python

    # DEFAULT (-1 doesn't track norms)
    trainer = Trainer(track_grad_norm=-1)

    # track the p-norm of the gradients (p=2 here)
    trainer = Trainer(track_grad_norm=2)
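
The tracked quantity is, roughly, the p-norm of each parameter's gradient (a simplified illustration, not Lightning's actual logging code; `model` stands for your module):

.. code-block:: python

    # per-parameter 2-norms of the gradients, gathered into a plain dict
    grad_norms = {
        name: param.grad.norm(2).item()
        for name, param in model.named_parameters()
        if param.grad is not None
    }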

Set how much of the training set to check
-----------------------------------------

If you don't want to check 100% of the training set (for debugging, or because it's huge), set this flag.

`limit_train_batches` will be overwritten by `overfit_batches` if `overfit_batches > 0`.

.. code-block:: python

    # DEFAULT
    trainer = Trainer(limit_train_batches=1.0)

    # check 10% only
    trainer = Trainer(limit_train_batches=0.1)

    # check 10 batches only
    trainer = Trainer(limit_train_batches=10)

Packed sequences as inputs
--------------------------

When using `PackedSequence`, do two things:

1. Return either a padded tensor from the dataset or a list of variable-length tensors
   from the dataloader's `collate_fn` (the example below shows the list implementation).
2. Pack the sequence in `forward` or in the training and validation steps, depending on the use case.

.. code-block:: python

    from torch.nn.utils import rnn

    # for use in the dataloader: keep the variable-length tensors in plain lists
    def collate_fn(batch):
        x = [item[0] for item in batch]
        y = [item[1] for item in batch]
        return x, y

    # in the LightningModule: pack the lists of tensors inside training_step
    def training_step(self, batch, batch_idx):
        x = rnn.pack_sequence(batch[0], enforce_sorted=False)
        y = rnn.pack_sequence(batch[1], enforce_sorted=False)
        # ... run the model on the packed inputs and return the loss
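
A packed batch can then be consumed directly by PyTorch's recurrent layers. A self-contained, minimal sketch (the layer sizes and random data are purely illustrative):

.. code-block:: python

    import torch
    from torch import nn
    from torch.nn.utils import rnn

    # two variable-length sequences with 10 features per time step
    seqs = [torch.randn(5, 10), torch.randn(3, 10)]
    packed = rnn.pack_sequence(seqs, enforce_sorted=False)

    # nn.LSTM accepts PackedSequence inputs directly
    lstm = nn.LSTM(input_size=10, hidden_size=20)
    output, (h_n, c_n) = lstm(packed)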

Truncated Backpropagation Through Time
--------------------------------------

There are times when multiple backward passes are needed for each batch.
For example, it may save memory to use truncated backpropagation through time when training RNNs.

When this flag is enabled, each batch is split into sequences of size `truncated_bptt_steps`
and passed to `training_step(...)` separately. A default splitting function is provided;
however, you can override it for more flexibility. See `tbptt_split_batch`.

.. code-block:: python

    # DEFAULT (single backward pass per batch)
    trainer = Trainer(truncated_bptt_steps=None)

    # split each batch into sequences of size 2
    trainer = Trainer(truncated_bptt_steps=2)
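
A minimal sketch of overriding the split function in your LightningModule, assuming each batch is an `(x, y)` tuple with the time dimension along dim 1 (this batch layout is an assumption for illustration, not a requirement):

.. code-block:: python

    def tbptt_split_batch(self, batch, split_size):
        # split (x, y) along the time dimension into chunks of `split_size` steps
        x, y = batch
        splits = []
        for t in range(0, x.size(1), split_size):
            splits.append((x[:, t:t + split_size], y[:, t:t + split_size]))
        return splits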

NaN detection and intervention
------------------------------

When the `terminate_on_nan` flag is enabled, after every forward pass during training, Lightning will
check that:

1. the loss you return in `training_step` is finite (not NaN and not +/- inf), and
2. the model parameters have finite values.

Lightning will terminate the training loop with an error message if NaN or infinite
values are detected. If this happens, you should investigate numerically unstable operations
in your model.

.. code-block:: python

    # DEFAULT (won't perform the NaN check)
    trainer = Trainer(terminate_on_nan=False)

    # check each batch and terminate on NaN or infinite values
    trainer = Trainer(terminate_on_nan=True)
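
Conceptually, the check is equivalent to the following (a simplified sketch, not Lightning's actual implementation; `loss` and `model` stand for the batch loss and your module):

.. code-block:: python

    import torch

    # after computing `loss` for a training batch:
    if not torch.isfinite(loss).all():
        raise ValueError("Training loss is NaN or infinite")

    # and after the parameters have been updated:
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            raise ValueError(f"Parameter {name} contains NaN or infinite values")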
"""

from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, List

from torch.utils.data import DataLoader

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
@ -171,12 +24,9 @@ from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.utilities.debugging import InternalDebugger


class TrainerTrainLoopMixin(ABC):
    # this is just a summary of the variables used in this abstract class;
    # the proper values/initialisation should be done in the child class
    on_gpu: bool
    use_horovod: bool
    check_val_every_n_epoch: ...
@ -47,7 +47,8 @@ def test_training_step_dict(tmpdir):
    assert pbar_metrics['pbar_acc2'] == 19.0

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)

@ -84,7 +84,8 @@ def test_training_step_result_log_step_only(tmpdir):
    assert f'step_log_acc2_b{batch_idx}' in train_step_out

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)

@ -158,7 +159,8 @@ def test_training_step_result_log_epoch_only(tmpdir):
    assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)

@ -293,7 +295,8 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
    assert 'epoch_step_epoch_log_acc2' in train_step_out

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)

@ -372,7 +375,8 @@ def test_training_step_epoch_end_result(tmpdir):
    assert 'epoch_step_epoch_log_acc2' in train_step_out

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)

@ -43,7 +43,8 @@ def test_training_step_scalar(tmpdir):
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171

@ -80,7 +81,8 @@ def training_step_scalar_with_step_end(tmpdir):
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171

@ -127,7 +129,8 @@ def test_full_training_loop_scalar(tmpdir):
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171

@ -170,5 +173,6 @@ def test_train_step_epoch_end_scalar(tmpdir):
    assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
-   opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+   opt_closure_result = trainer.train_loop.training_step_and_backward(
+       batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'].item() == 171