Update hooks pseudocode (#7713)

Carlos Mocholí 2021-05-27 12:27:26 +02:00 committed by GitHub
parent 04dcb1786d
commit 906c067b07
3 changed files with 65 additions and 31 deletions

CHANGELOG.md

@@ -74,7 +74,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
     * Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
     * Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526))
     * Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
+    * Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
 - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))

docs/source/common/lightning_module.rst

@@ -1064,7 +1064,9 @@ override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:
 Hooks
 ^^^^^
-This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.
+This is the pseudocode to describe the structure of :meth:`~pytorch_lightning.trainer.Trainer.fit`.
+The inputs and outputs of each function are not represented for simplicity. Please check each function's API reference
+for more information.
 
 .. code-block:: python
@@ -1075,36 +1077,41 @@ This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.
+    configure_callbacks()
 
-    on_fit_start()
-
-    for gpu/tpu in gpu/tpus:
-        train_on_device(model.copy())
-
-    on_fit_end()
+    with parallel(devices):
+        # devices can be GPUs, TPUs, ...
+        train_on_device(model)
 
 
 def train_on_device(model):
-    # setup is called PER DEVICE
-    setup()
+    # called PER DEVICE
+    on_fit_start()
+    setup('fit')
     configure_optimizers()
     on_pretrain_routine_start()
+    on_pretrain_routine_end()
+
+    # the sanity check runs here
 
+    on_train_start()
     for epoch in epochs:
         train_loop()
+    on_train_end()
 
-    teardown()
+    on_fit_end()
+    teardown('fit')
 
 
 def train_loop():
     on_epoch_start()
     on_train_epoch_start()
-    train_outs = []
-    for train_batch in train_dataloader():
+
+    for batch in train_dataloader():
         on_train_batch_start()
 
-        # ----- train_step methods -------
-        out = training_step(batch)
-        train_outs.append(out)
-
-        loss = out.loss
+        on_before_batch_transfer()
+        transfer_batch_to_device()
+        on_after_batch_transfer()
+
+        training_step()
 
         on_before_zero_grad()
         optimizer_zero_grad()
@@ -1114,38 +1121,42 @@ This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.
         optimizer_step()
 
-        on_train_batch_end(out)
+        on_train_batch_end()
 
         if should_check_val:
             val_loop()
+
     # end training epoch
-    training_epoch_end(outs)
-    on_train_epoch_end(outs)
+    training_epoch_end()
+
+    on_train_epoch_end()
     on_epoch_end()
 
 
 def val_loop():
-    model.eval()
+    on_validation_model_eval()  # calls `model.eval()`
     torch.set_grad_enabled(False)
 
     on_validation_start()
     on_epoch_start()
     on_validation_epoch_start()
 
-    val_outs = []
-    for val_batch in val_dataloader():
+    for batch in val_dataloader():
         on_validation_batch_start()
 
-        # -------- val step methods -------
-        out = validation_step(val_batch)
-        val_outs.append(out)
+        on_before_batch_transfer()
+        transfer_batch_to_device()
+        on_after_batch_transfer()
 
-        on_validation_batch_end(out)
+        validation_step()
+
+        on_validation_batch_end()
 
-    validation_epoch_end(val_outs)
+    validation_epoch_end()
     on_validation_epoch_end()
     on_epoch_end()
     on_validation_end()
 
     # set up for train
-    model.train()
+    on_validation_model_train()  # calls `model.train()`
     torch.set_grad_enabled(True)
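As a quick check on the pseudocode above, the ordering can be observed by overriding a handful of these hooks and printing their names. The following is a minimal sketch, not part of this commit, assuming the 1.3-era pytorch_lightning API current when this commit landed; the model, data, and prints are illustrative only.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class PrintHooksModel(pl.LightningModule):
    """Prints a few hooks to make the Trainer's call order visible."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def setup(self, stage=None):
        # per the pseudocode, runs per device with stage == 'fit'
        print(f"setup({stage!r})")

    def on_fit_start(self):
        print("on_fit_start")

    def configure_optimizers(self):
        print("configure_optimizers")
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_train_start(self):
        print("on_train_start")

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def on_train_end(self):
        print("on_train_end")

    def teardown(self, stage=None):
        print(f"teardown({stage!r})")

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))
        return DataLoader(dataset, batch_size=4)


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1, logger=False, checkpoint_callback=False)
    trainer.fit(PrintHooksModel())

Running this prints the hook names in the order the Trainer actually invokes them, which can be compared against the pseudocode above (the docs intentionally elide inputs and outputs).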

tests/models/test_hooks.py

@@ -255,6 +255,18 @@ class HookedModel(BoringModel):
             'on_validation_batch_end',
         ]
 
+    def prepare_data(self):
+        self.called.append("prepare_data")
+        return super().prepare_data()
+
+    def configure_callbacks(self):
+        self.called.append("configure_callbacks")
+        return super().configure_callbacks()
+
+    def configure_optimizers(self):
+        self.called.append("configure_optimizers")
+        return super().configure_optimizers()
+
     def training_step(self, *args, **kwargs):
         self.called.append("training_step")
         return super().training_step(*args, **kwargs)
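The three overrides added above all have the same shape: append the hook's name to `self.called`, then delegate to the parent implementation. A generic alternative is sketched below; `record_hooks` and `hook_names` are hypothetical helpers, not part of the test suite. The test file spells each override out by hand instead, which also keeps any signature inspection the Trainer performs on hooks working unchanged.

import functools


def record_hooks(model, hook_names):
    # Hypothetical helper: wrap each named method so it logs its own name
    # into `model.called` before delegating to the original implementation.
    model.called = []
    for name in hook_names:
        original = getattr(model, name)

        @functools.wraps(original)
        def wrapped(*args, _name=name, _original=original, **kwargs):
            model.called.append(_name)
            return _original(*args, **kwargs)

        setattr(model, name, wrapped)
    return model


# e.g. record_hooks(model, ["prepare_data", "configure_callbacks", "configure_optimizers"])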
@@ -451,7 +463,10 @@ def test_trainer_model_hook_system_fit(tmpdir):
     assert model.called == []
     trainer.fit(model)
     expected = [
+        'prepare_data',
+        'configure_callbacks',
         'setup_fit',
+        'configure_optimizers',
         'on_fit_start',
         'on_pretrain_routine_start',
         'on_pretrain_routine_end',
@@ -504,7 +519,10 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
     assert model.called == []
     trainer.fit(model)
     expected = [
+        'prepare_data',
+        'configure_callbacks',
         'setup_fit',
+        'configure_optimizers',
         'on_fit_start',
         'on_pretrain_routine_start',
         'on_pretrain_routine_end',
@@ -535,6 +553,8 @@ def test_trainer_model_hook_system_validate(tmpdir):
     assert model.called == []
     trainer.validate(model, verbose=False)
     expected = [
+        'prepare_data',
+        'configure_callbacks',
         'setup_validate',
         'on_validation_model_eval',
         'on_validation_start',
@@ -567,6 +587,8 @@ def test_trainer_model_hook_system_test(tmpdir):
     assert model.called == []
     trainer.test(model, verbose=False)
     expected = [
+        'prepare_data',
+        'configure_callbacks',
         'setup_test',
         'on_test_model_eval',
         'on_test_start',
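The expected lists above (truncated here by the hunk boundaries) all gain the same prefix: `prepare_data` and `configure_callbacks` now come first, followed by the stage-specific `setup_*` hook. A minimal sketch of asserting just that prefix, assuming `HookedModel` from this diff is in scope (it is internal to the test suite, not importable from the package):

import pytorch_lightning as pl


def check_validate_hook_prefix(tmpdir):
    model = HookedModel()  # the test class extended in this diff
    trainer = pl.Trainer(default_root_dir=tmpdir, limit_val_batches=1)
    trainer.validate(model, verbose=False)
    assert model.called[:5] == [
        'prepare_data',
        'configure_callbacks',
        'setup_validate',
        'on_validation_model_eval',
        'on_validation_start',
    ]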