Update hooks pseudocode (#7713)
This commit is contained in:
parent
04dcb1786d
commit
906c067b07
|
@ -74,7 +74,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
|
||||
* Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
|
||||
* Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526))
|
||||
|
||||
* Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
|
||||
* Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
|
||||
|
||||
- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
|
||||
|
||||
|
|
|
@ -1064,7 +1064,9 @@ override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:
|
|||
|
||||
Hooks
|
||||
^^^^^
|
||||
This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.
|
||||
This is the pseudocode to describe the structure of :meth:`~pytorch_lightning.trainer.Trainer.fit`.
|
||||
The inputs and outputs of each function are not represented for simplicity. Please check each function's API reference
|
||||
for more information.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -1075,36 +1077,41 @@ This is the pseudocode to describe how all the hooks are called during a call to
|
|||
|
||||
configure_callbacks()
|
||||
|
||||
on_fit_start()
|
||||
|
||||
for gpu/tpu in gpu/tpus:
|
||||
train_on_device(model.copy())
|
||||
|
||||
on_fit_end()
|
||||
with parallel(devices):
|
||||
# devices can be GPUs, TPUs, ...
|
||||
train_on_device(model)
|
||||
|
||||
def train_on_device(model):
|
||||
# setup is called PER DEVICE
|
||||
setup()
|
||||
# called PER DEVICE
|
||||
on_fit_start()
|
||||
setup('fit')
|
||||
configure_optimizers()
|
||||
on_pretrain_routine_start()
|
||||
|
||||
on_pretrain_routine_start()
|
||||
on_pretrain_routine_end()
|
||||
|
||||
# the sanity check runs here
|
||||
|
||||
on_train_start()
|
||||
for epoch in epochs:
|
||||
train_loop()
|
||||
on_train_end()
|
||||
|
||||
teardown()
|
||||
on_fit_end()
|
||||
teardown('fit')
|
||||
|
||||
def train_loop():
|
||||
on_epoch_start()
|
||||
on_train_epoch_start()
|
||||
train_outs = []
|
||||
for train_batch in train_dataloader():
|
||||
|
||||
for batch in train_dataloader():
|
||||
on_train_batch_start()
|
||||
|
||||
# ----- train_step methods -------
|
||||
out = training_step(batch)
|
||||
train_outs.append(out)
|
||||
on_before_batch_transfer()
|
||||
transfer_batch_to_device()
|
||||
on_after_batch_transfer()
|
||||
|
||||
loss = out.loss
|
||||
training_step()
|
||||
|
||||
on_before_zero_grad()
|
||||
optimizer_zero_grad()
|
||||
|
@ -1114,38 +1121,42 @@ This is the pseudocode to describe how all the hooks are called during a call to
|
|||
|
||||
optimizer_step()
|
||||
|
||||
on_train_batch_end(out)
|
||||
on_train_batch_end()
|
||||
|
||||
if should_check_val:
|
||||
val_loop()
|
||||
|
||||
# end training epoch
|
||||
training_epoch_end(outs)
|
||||
on_train_epoch_end(outs)
|
||||
training_epoch_end()
|
||||
|
||||
on_train_epoch_end()
|
||||
on_epoch_end()
|
||||
|
||||
def val_loop():
|
||||
model.eval()
|
||||
on_validation_model_eval() # calls `model.eval()`
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
on_validation_start()
|
||||
on_epoch_start()
|
||||
on_validation_epoch_start()
|
||||
val_outs = []
|
||||
for val_batch in val_dataloader():
|
||||
|
||||
for batch in val_dataloader():
|
||||
on_validation_batch_start()
|
||||
|
||||
# -------- val step methods -------
|
||||
out = validation_step(val_batch)
|
||||
val_outs.append(out)
|
||||
on_before_batch_transfer()
|
||||
transfer_batch_to_device()
|
||||
on_after_batch_transfer()
|
||||
|
||||
on_validation_batch_end(out)
|
||||
validation_step()
|
||||
|
||||
on_validation_batch_end()
|
||||
validation_epoch_end()
|
||||
|
||||
validation_epoch_end(val_outs)
|
||||
on_validation_epoch_end()
|
||||
on_epoch_end()
|
||||
on_validation_end()
|
||||
|
||||
# set up for train
|
||||
model.train()
|
||||
on_validation_model_train() # calls `model.train()`
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
backward
|
||||
|
|
|
@ -255,6 +255,18 @@ class HookedModel(BoringModel):
|
|||
'on_validation_batch_end',
|
||||
]
|
||||
|
||||
def prepare_data(self):
|
||||
self.called.append("prepare_data")
|
||||
return super().prepare_data()
|
||||
|
||||
def configure_callbacks(self):
|
||||
self.called.append("configure_callbacks")
|
||||
return super().configure_callbacks()
|
||||
|
||||
def configure_optimizers(self):
|
||||
self.called.append("configure_optimizers")
|
||||
return super().configure_optimizers()
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
self.called.append("training_step")
|
||||
return super().training_step(*args, **kwargs)
|
||||
|
@ -451,7 +463,10 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
|||
assert model.called == []
|
||||
trainer.fit(model)
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup_fit',
|
||||
'configure_optimizers',
|
||||
'on_fit_start',
|
||||
'on_pretrain_routine_start',
|
||||
'on_pretrain_routine_end',
|
||||
|
@ -504,7 +519,10 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
|
|||
assert model.called == []
|
||||
trainer.fit(model)
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup_fit',
|
||||
'configure_optimizers',
|
||||
'on_fit_start',
|
||||
'on_pretrain_routine_start',
|
||||
'on_pretrain_routine_end',
|
||||
|
@ -535,6 +553,8 @@ def test_trainer_model_hook_system_validate(tmpdir):
|
|||
assert model.called == []
|
||||
trainer.validate(model, verbose=False)
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup_validate',
|
||||
'on_validation_model_eval',
|
||||
'on_validation_start',
|
||||
|
@ -567,6 +587,8 @@ def test_trainer_model_hook_system_test(tmpdir):
|
|||
assert model.called == []
|
||||
trainer.test(model, verbose=False)
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup_test',
|
||||
'on_test_model_eval',
|
||||
'on_test_start',
|
||||
|
|
Loading…
Reference in New Issue