Update fit with val hook test (#8060)

Carlos Mocholí 2021-06-21 19:27:37 +02:00 committed by GitHub
parent dd340a6598
commit f1fa4c4727
4 changed files with 123 additions and 172 deletions

View File

@@ -254,7 +254,7 @@ class TrainerCallbackHookMixin(ABC):
@staticmethod
def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
parameters = list(signature(fn).parameters)
return len(parameters) == 2 and parameters[1] != "args"
return len(parameters) == 2 and parameters[0] != "args"
@staticmethod
def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:

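As context for the one-line fix above: the helper flags a callback whose `on_save_checkpoint` still uses the deprecated two-argument signature, and the index change matters because a `*args, **kwargs` hook also inspects to exactly two parameters. A minimal sketch (not the library code) of the distinction:

from inspect import signature


class OldStyleCallback:
    # deprecated signature: no `checkpoint` argument
    def on_save_checkpoint(self, trainer, pl_module):
        ...


class StarArgsCallback:
    # must NOT be flagged as old-style
    def on_save_checkpoint(self, *args, **kwargs):
        ...


def is_old_signature(fn) -> bool:
    parameters = list(signature(fn).parameters)
    # Bound methods drop `self`, so the deprecated form has exactly two parameters.
    # Checking `parameters[0]` (not `parameters[1]`) also rejects `*args, **kwargs`,
    # whose parameter list is likewise two entries long: ["args", "kwargs"].
    return len(parameters) == 2 and parameters[0] != "args"


assert is_old_signature(OldStyleCallback().on_save_checkpoint)
assert not is_old_signature(StarArgsCallback().on_save_checkpoint)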
View File

@@ -366,7 +366,7 @@ class ResultCollection(dict):
return self.get('_extra', {})
@extra.setter
def extra(self, extra: Mapping[str, Any]) -> None:
def extra(self, extra: Dict[str, Any]) -> None:
def check_fn(v):
if v.grad_fn is not None:
@@ -378,7 +378,8 @@ class ResultCollection(dict):
return v.detach()
return v
extra = apply_to_collection(extra, torch.Tensor, check_fn)
# update instead of replace to keep the extra dict reference. TODO: remove with v1.6 deprecation removal
extra.update(apply_to_collection(extra, torch.Tensor, check_fn))
self['_extra'] = extra
def log(

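The switch from assignment to `update` keeps the identity of the `extra` mapping stable, so code that grabbed a reference to it before the setter ran still observes the detached values. A small sketch of that general behaviour (plain dicts, no Lightning types):

# Rebinding creates a brand new dict; anything still pointing at the old one goes stale.
extra = {'foo': 1}
alias = extra                  # e.g. a reference handed out before the setter ran
extra = {'foo': 2}             # "replace": alias still sees the old contents
assert alias == {'foo': 1}

# Updating mutates in place, so every holder of the reference observes the new values.
extra = alias
extra.update({'foo': 2})
assert alias == {'foo': 2}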
View File

@@ -11,95 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import ANY, call, MagicMock, Mock
from unittest.mock import call, Mock
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
@mock.patch("torch.save") # need to mock torch.save or we get pickle error
def test_trainer_callback_hook_system_fit(_, tmpdir):
"""Test the callback hook system for fit."""
model = BoringModel()
callback_mock = MagicMock()
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[callback_mock],
max_epochs=1,
limit_val_batches=1,
limit_train_batches=3,
progress_bar_refresh_rate=0,
)
# check that only the two init calls exist
assert trainer.callbacks[0] == callback_mock
assert callback_mock.method_calls == [
call.on_init_start(trainer),
call.on_init_end(trainer),
]
# fit model
trainer.fit(model)
assert callback_mock.method_calls == [
call.on_init_start(trainer),
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, stage='fit'),
call.on_configure_sharded_model(trainer, model),
call.on_fit_start(trainer, model),
call.on_pretrain_routine_start(trainer, model),
call.on_pretrain_routine_end(trainer, model),
call.on_sanity_check_start(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_sanity_check_end(trainer, model),
call.on_train_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_train_epoch_start(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 0, 0),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_batch_end(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 1, 0),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_batch_end(trainer, model),
call.on_batch_start(trainer, model),
call.on_train_batch_start(trainer, model, ANY, 2, 0),
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_batch_end(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, stage='fit'),
]
def test_callbacks_configured_in_model(tmpdir):
""" Test the callback system with callbacks added through the model hook. """

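The deleted `test_trainer_callback_hook_system_fit` relied on `MagicMock` recording every hook invocation in `method_calls`; the equivalent coverage now lives in the hook tests in the next file, which track calls explicitly. A minimal sketch of the mocking pattern being retired (standard library only):

from unittest.mock import ANY, MagicMock, call

recorder = MagicMock()
recorder.setup('trainer', 'model', stage='fit')
recorder.on_train_batch_start('trainer', 'model', object(), 0, 0)

# `method_calls` keeps the invocation order together with the arguments,
# and `ANY` matches arguments we do not want to pin down.
assert recorder.method_calls == [
    call.setup('trainer', 'model', stage='fit'),
    call.on_train_batch_start('trainer', 'model', ANY, 0, 0),
]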
View File

@@ -283,35 +283,38 @@ class HookedModel(BoringModel):
pass
@staticmethod
def _train_batch():
return [
'on_train_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'forward',
'training_step',
'training_step_end',
'on_before_zero_grad',
'optimizer_zero_grad',
'backward',
'on_after_backward',
'optimizer_step',
'on_train_batch_end',
]
@staticmethod
def _val_batch():
return [
'on_validation_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'forward',
'validation_step',
'validation_step_end',
'on_validation_batch_end',
]
def _train_batch(trainer, model, batches):
out = []
for i in range(batches):
out.extend([
# TODO: `on_batch_{start,end}`
dict(name='Callback.on_batch_start', args=(trainer, model)),
dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_train_batch_start', args=(ANY, i, 0)),
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='forward', args=(ANY, )),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)),
dict(name='on_before_zero_grad', args=(ANY, )),
dict(name='optimizer_zero_grad', args=(0, i, ANY, 0)),
# TODO: `on_before_backward`
dict(name='backward', args=(ANY, ANY, 0)),
dict(name='Callback.on_after_backward', args=(trainer, model)),
dict(name='on_after_backward'),
# TODO: `on_before_optimizer_step`
dict(
name='optimizer_step',
args=(0, i, ANY, 0, ANY),
kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False)
),
dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)),
dict(name='Callback.on_batch_end', args=(trainer, model)),
])
return out
@staticmethod
def _eval_epoch(fn, trainer, model, batches, key):
@@ -372,6 +375,7 @@ class HookedModel(BoringModel):
def test_trainer_model_hook_system_fit(tmpdir):
called = []
model = HookedModel(called)
callback = HookedCallback(called)
train_batches = 2
val_batches = 2
trainer = Trainer(
@@ -381,63 +385,89 @@ def test_trainer_model_hook_system_fit(tmpdir):
limit_val_batches=val_batches,
progress_bar_refresh_rate=0,
weights_summary=None,
callbacks=[callback]
)
assert called == []
trainer.fit(model)
expected = [
'prepare_data',
'configure_callbacks',
'setup',
'configure_sharded_model',
'configure_optimizers',
'on_fit_start',
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_val_dataloader',
'val_dataloader',
'train', # eval() == train(False)
'on_validation_model_eval',
'zero_grad',
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
*(HookedModel._val_batch() * val_batches),
'validation_epoch_end',
'on_validation_epoch_end',
'on_epoch_end',
'on_validation_end',
'train',
'on_validation_model_train',
# duplicate `train` because `_run_train` calls it again in case validation wasn't run
'train',
'on_train_dataloader',
'train_dataloader',
'on_train_start',
'on_epoch_start',
'on_train_epoch_start',
*(HookedModel._train_batch() * train_batches),
'train', # eval() == train(False)
'on_validation_model_eval',
'zero_grad',
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
*(HookedModel._val_batch() * val_batches),
'validation_epoch_end',
'on_validation_epoch_end',
'on_epoch_end',
'on_save_checkpoint',
'on_validation_end',
'train',
'on_validation_model_train',
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_train_end',
'on_fit_end',
'teardown',
assert called == [
dict(name='Callback.on_init_start', args=(trainer, )),
dict(name='Callback.on_init_end', args=(trainer, )),
]
trainer.fit(model)
saved_ckpt = {
'callbacks': ANY,
'epoch': 1,
'global_step': train_batches,
'lr_schedulers': ANY,
'optimizer_states': ANY,
'pytorch-lightning_version': __version__,
'state_dict': ANY,
}
expected = [
dict(name='Callback.on_init_start', args=(trainer, )),
dict(name='Callback.on_init_end', args=(trainer, )),
dict(name='prepare_data'),
dict(name='configure_callbacks'),
dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
dict(name='setup', kwargs=dict(stage='fit')),
dict(name='configure_sharded_model'),
dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
dict(name='configure_optimizers'),
dict(name='Callback.on_fit_start', args=(trainer, model)),
dict(name='on_fit_start'),
dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
dict(name='on_pretrain_routine_start'),
dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
dict(name='on_pretrain_routine_end'),
dict(name='Callback.on_sanity_check_start', args=(trainer, model)),
dict(name='on_val_dataloader'),
dict(name='val_dataloader'),
dict(name='train', args=(False, )),
dict(name='on_validation_model_eval'),
dict(name='zero_grad'),
dict(name='Callback.on_validation_start', args=(trainer, model)),
dict(name='on_validation_start'),
*model._eval_epoch('validation', trainer, model, val_batches, 'x'),
dict(name='Callback.on_validation_end', args=(trainer, model)),
dict(name='on_validation_end'),
dict(name='train'),
dict(name='on_validation_model_train'),
dict(name='Callback.on_sanity_check_end', args=(trainer, model)),
# duplicate `train` because `_run_train` calls it again in case validation wasn't run
dict(name='train'),
dict(name='on_train_dataloader'),
dict(name='train_dataloader'),
dict(name='Callback.on_train_start', args=(trainer, model)),
dict(name='on_train_start'),
dict(name='Callback.on_epoch_start', args=(trainer, model)),
dict(name='on_epoch_start'),
dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
dict(name='on_train_epoch_start'),
*model._train_batch(trainer, model, train_batches),
dict(name='train', args=(False, )),
dict(name='on_validation_model_eval'),
dict(name='zero_grad'),
dict(name='Callback.on_validation_start', args=(trainer, model)),
dict(name='on_validation_start'),
*model._eval_epoch('validation', trainer, model, val_batches, 'x'),
dict(name='Callback.on_validation_end', args=(trainer, model)),
# `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
dict(name='on_save_checkpoint', args=(saved_ckpt, )),
dict(name='on_validation_end'),
dict(name='train'),
dict(name='on_validation_model_train'),
dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
dict(name='Callback.on_epoch_end', args=(trainer, model)),
dict(name='on_epoch_end'),
dict(name='Callback.on_train_end', args=(trainer, model)),
dict(name='on_train_end'),
dict(name='Callback.on_fit_end', args=(trainer, model)),
dict(name='on_fit_end'),
dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='fit')),
dict(name='teardown', kwargs=dict(stage='fit')),
]
called = [c['name'] for c in called]
assert called == expected
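The `expected` entries above are dictionaries of hook name, positional args, and keyword args. `HookedCallback` itself is defined elsewhere in the test file and is not part of this diff; the following is only an assumed sketch (with a hypothetical `RecordingCallback` name) of how such a recorder can be built by shadowing every `Callback` hook with a small closure:

from functools import partial

from pytorch_lightning.callbacks import Callback


class RecordingCallback(Callback):  # hypothetical name, for illustration only

    def __init__(self, called):
        def record(hook_name, *args, **kwargs):
            entry = {'name': f'Callback.{hook_name}'}
            if args:
                entry['args'] = args
            if kwargs:
                entry['kwargs'] = kwargs
            called.append(entry)

        # Shadow every hook on the instance so each call lands in the shared
        # `called` list alongside the entries recorded by the model.
        hooks = [h for h in dir(Callback) if h.startswith('on_') or h in ('setup', 'teardown')]
        for hook_name in hooks:
            setattr(self, hook_name, partial(record, hook_name))


called = []
cb = RecordingCallback(called)
cb.setup('trainer', 'model', stage='fit')
assert called == [dict(name='Callback.setup', args=('trainer', 'model'), kwargs=dict(stage='fit'))]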
@@ -488,7 +518,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
'on_train_start',
'on_epoch_start',
'on_train_epoch_start',
*(HookedModel._train_batch() * train_batches),
*[
h['name']
for h in HookedModel._train_batch(trainer, model, train_batches) if not h['name'].startswith('Callback')
],
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',