Update fit with val hook test (#8060)
parent dd340a6598
commit f1fa4c4727

@@ -254,7 +254,7 @@ class TrainerCallbackHookMixin(ABC):
     @staticmethod
     def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
         parameters = list(signature(fn).parameters)
-        return len(parameters) == 2 and parameters[1] != "args"
+        return len(parameters) == 2 and parameters[0] != "args"

     @staticmethod
     def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
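
The hunk above changes which parameter name is compared against "args" when the hook signature is inspected for backward compatibility. A minimal sketch of why index 0 is the one to test (not the Lightning implementation; the example classes below are made up):

from inspect import signature
from typing import Callable

# A hook written as `*args, **kwargs` has parameter names ['args', 'kwargs'],
# so checking parameters[1] would see 'kwargs' and misclassify it as the old
# two-argument signature, while parameters[0] correctly sees 'args'.
def is_old_signature(fn: Callable) -> bool:
    parameters = list(signature(fn).parameters)
    return len(parameters) == 2 and parameters[0] != "args"

class OldStyle:
    def on_save_checkpoint(self, trainer, pl_module):
        ...

class NewStyle:
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        ...

class Forwarding:
    def on_save_checkpoint(self, *args, **kwargs):
        ...

assert is_old_signature(OldStyle().on_save_checkpoint)
assert not is_old_signature(NewStyle().on_save_checkpoint)
assert not is_old_signature(Forwarding().on_save_checkpoint)
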
@@ -366,7 +366,7 @@ class ResultCollection(dict):
         return self.get('_extra', {})

     @extra.setter
-    def extra(self, extra: Mapping[str, Any]) -> None:
+    def extra(self, extra: Dict[str, Any]) -> None:

         def check_fn(v):
             if v.grad_fn is not None:
@@ -378,7 +378,8 @@ class ResultCollection(dict):
                 return v.detach()
             return v

-        extra = apply_to_collection(extra, torch.Tensor, check_fn)
+        # update instead of replace to keep the extra dict reference. TODO: remove with v1.6 deprecation removal
+        extra.update(apply_to_collection(extra, torch.Tensor, check_fn))
         self['_extra'] = extra

     def log(
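
The second hunk replaces rebinding of `extra` with an in-place `update`, so references that already point at the dict keep seeing the detached values. A small standalone illustration of the difference, using a plain dict comprehension instead of Lightning's `apply_to_collection`:

import torch

def detached(d):
    # Build a new dict whose tensor values are detached from the graph.
    return {k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in d.items()}

extra = {'x': torch.tensor(1.0, requires_grad=True) * 2}
alias = extra                  # a reference held elsewhere, e.g. self['_extra']

extra = detached(extra)        # rebinding: `alias` still holds the attached tensor
assert alias['x'].grad_fn is not None

alias.update(detached(alias))  # in-place update: the same dict object now holds detached values
assert alias['x'].grad_fn is None
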
@@ -11,95 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest import mock
-from unittest.mock import ANY, call, MagicMock, Mock
+from unittest.mock import call, Mock

 from pytorch_lightning import Trainer
 from tests.helpers import BoringModel


-@mock.patch("torch.save") # need to mock torch.save or we get pickle error
-def test_trainer_callback_hook_system_fit(_, tmpdir):
-    """Test the callback hook system for fit."""
-
-    model = BoringModel()
-    callback_mock = MagicMock()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=[callback_mock],
-        max_epochs=1,
-        limit_val_batches=1,
-        limit_train_batches=3,
-        progress_bar_refresh_rate=0,
-    )
-
-    # check that only the to calls exists
-    assert trainer.callbacks[0] == callback_mock
-    assert callback_mock.method_calls == [
-        call.on_init_start(trainer),
-        call.on_init_end(trainer),
-    ]
-
-    # fit model
-    trainer.fit(model)
-
-    assert callback_mock.method_calls == [
-        call.on_init_start(trainer),
-        call.on_init_end(trainer),
-        call.on_before_accelerator_backend_setup(trainer, model),
-        call.setup(trainer, model, stage='fit'),
-        call.on_configure_sharded_model(trainer, model),
-        call.on_fit_start(trainer, model),
-        call.on_pretrain_routine_start(trainer, model),
-        call.on_pretrain_routine_end(trainer, model),
-        call.on_sanity_check_start(trainer, model),
-        call.on_validation_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_validation_epoch_start(trainer, model),
-        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
-        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_validation_epoch_end(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_validation_end(trainer, model),
-        call.on_sanity_check_end(trainer, model),
-        call.on_train_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_train_epoch_start(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 0, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_batch_end(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 1, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
-        call.on_batch_end(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 2, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
-        call.on_batch_end(trainer, model),
-        call.on_validation_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_validation_epoch_start(trainer, model),
-        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
-        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_validation_epoch_end(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_validation_end(trainer, model),
-        call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
-        call.on_train_epoch_end(trainer, model, ANY),
-        call.on_epoch_end(trainer, model),
-        call.on_train_end(trainer, model),
-        call.on_fit_end(trainer, model),
-        call.teardown(trainer, model, stage='fit'),
-    ]


 def test_callbacks_configured_in_model(tmpdir):
     """ Test the callback system with callbacks added through the model hook. """

@@ -283,35 +283,38 @@ class HookedModel(BoringModel):
         pass

     @staticmethod
-    def _train_batch():
-        return [
-            'on_train_batch_start',
-            'on_before_batch_transfer',
-            'transfer_batch_to_device',
-            'on_after_batch_transfer',
-            'forward',
-            'training_step',
-            'training_step_end',
-            'on_before_zero_grad',
-            'optimizer_zero_grad',
-            'backward',
-            'on_after_backward',
-            'optimizer_step',
-            'on_train_batch_end',
-        ]
-
-    @staticmethod
-    def _val_batch():
-        return [
-            'on_validation_batch_start',
-            'on_before_batch_transfer',
-            'transfer_batch_to_device',
-            'on_after_batch_transfer',
-            'forward',
-            'validation_step',
-            'validation_step_end',
-            'on_validation_batch_end',
-        ]
+    def _train_batch(trainer, model, batches):
+        out = []
+        for i in range(batches):
+            out.extend([
+                # TODO: `on_batch_{start,end}`
+                dict(name='Callback.on_batch_start', args=(trainer, model)),
+                dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
+                dict(name='on_train_batch_start', args=(ANY, i, 0)),
+                dict(name='on_before_batch_transfer', args=(ANY, None)),
+                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
+                dict(name='on_after_batch_transfer', args=(ANY, None)),
+                dict(name='forward', args=(ANY, )),
+                dict(name='training_step', args=(ANY, i)),
+                dict(name='training_step_end', args=(dict(loss=ANY), )),
+                dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)),
+                dict(name='on_before_zero_grad', args=(ANY, )),
+                dict(name='optimizer_zero_grad', args=(0, i, ANY, 0)),
+                # TODO: `on_before_backward`
+                dict(name='backward', args=(ANY, ANY, 0)),
+                dict(name='Callback.on_after_backward', args=(trainer, model)),
+                dict(name='on_after_backward'),
+                # TODO: `on_before_optimizer_step`
+                dict(
+                    name='optimizer_step',
+                    args=(0, i, ANY, 0, ANY),
+                    kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False)
+                ),
+                dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
+                dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)),
+                dict(name='Callback.on_batch_end', args=(trainer, model)),
+            ])
+        return out

     @staticmethod
     def _eval_epoch(fn, trainer, model, batches, key):
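
Throughout these expected lists, `unittest.mock.ANY` stands in for arguments whose exact values do not matter (batches, outputs, optimizer instances). For reference, its equality semantics in a tiny standalone example:

from unittest.mock import ANY

# ANY compares equal to anything, so a recorded call with concrete arguments
# matches an expected entry that only pins down the values it cares about.
recorded = dict(name='training_step', args=(('a batch',), 0))
assert recorded == dict(name='training_step', args=(ANY, 0))
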
@@ -372,6 +375,7 @@ class HookedModel(BoringModel):
 def test_trainer_model_hook_system_fit(tmpdir):
     called = []
     model = HookedModel(called)
+    callback = HookedCallback(called)
     train_batches = 2
     val_batches = 2
     trainer = Trainer(
@@ -381,63 +385,89 @@ def test_trainer_model_hook_system_fit(tmpdir):
         limit_val_batches=val_batches,
         progress_bar_refresh_rate=0,
         weights_summary=None,
+        callbacks=[callback]
     )
-    assert called == []
-    trainer.fit(model)
-    expected = [
-        'prepare_data',
-        'configure_callbacks',
-        'setup',
-        'configure_sharded_model',
-        'configure_optimizers',
-        'on_fit_start',
-        'on_pretrain_routine_start',
-        'on_pretrain_routine_end',
-        'on_val_dataloader',
-        'val_dataloader',
-        'train', # eval() == train(False)
-        'on_validation_model_eval',
-        'zero_grad',
-        'on_validation_start',
-        'on_epoch_start',
-        'on_validation_epoch_start',
-        *(HookedModel._val_batch() * val_batches),
-        'validation_epoch_end',
-        'on_validation_epoch_end',
-        'on_epoch_end',
-        'on_validation_end',
-        'train',
-        'on_validation_model_train',
-        # duplicate `train` because `_run_train` calls it again in case validation wasn't run
-        'train',
-        'on_train_dataloader',
-        'train_dataloader',
-        'on_train_start',
-        'on_epoch_start',
-        'on_train_epoch_start',
-        *(HookedModel._train_batch() * train_batches),
-        'train', # eval() == train(False)
-        'on_validation_model_eval',
-        'zero_grad',
-        'on_validation_start',
-        'on_epoch_start',
-        'on_validation_epoch_start',
-        *(HookedModel._val_batch() * val_batches),
-        'validation_epoch_end',
-        'on_validation_epoch_end',
-        'on_epoch_end',
-        'on_save_checkpoint',
-        'on_validation_end',
-        'train',
-        'on_validation_model_train',
-        'training_epoch_end',
-        'on_train_epoch_end',
-        'on_epoch_end',
-        'on_train_end',
-        'on_fit_end',
-        'teardown',
+    assert called == [
+        dict(name='Callback.on_init_start', args=(trainer, )),
+        dict(name='Callback.on_init_end', args=(trainer, )),
+    ]
+    trainer.fit(model)
+    saved_ckpt = {
+        'callbacks': ANY,
+        'epoch': 1,
+        'global_step': train_batches,
+        'lr_schedulers': ANY,
+        'optimizer_states': ANY,
+        'pytorch-lightning_version': __version__,
+        'state_dict': ANY,
+    }
+    expected = [
+        dict(name='Callback.on_init_start', args=(trainer, )),
+        dict(name='Callback.on_init_end', args=(trainer, )),
+        dict(name='prepare_data'),
+        dict(name='configure_callbacks'),
+        dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
+        dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
+        dict(name='setup', kwargs=dict(stage='fit')),
+        dict(name='configure_sharded_model'),
+        dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
+        dict(name='configure_optimizers'),
+        dict(name='Callback.on_fit_start', args=(trainer, model)),
+        dict(name='on_fit_start'),
+        dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
+        dict(name='on_pretrain_routine_start'),
+        dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
+        dict(name='on_pretrain_routine_end'),
+        dict(name='Callback.on_sanity_check_start', args=(trainer, model)),
+        dict(name='on_val_dataloader'),
+        dict(name='val_dataloader'),
+        dict(name='train', args=(False, )),
+        dict(name='on_validation_model_eval'),
+        dict(name='zero_grad'),
+        dict(name='Callback.on_validation_start', args=(trainer, model)),
+        dict(name='on_validation_start'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        dict(name='Callback.on_validation_end', args=(trainer, model)),
+        dict(name='on_validation_end'),
+        dict(name='train'),
+        dict(name='on_validation_model_train'),
+        dict(name='Callback.on_sanity_check_end', args=(trainer, model)),
+        # duplicate `train` because `_run_train` calls it again in case validation wasn't run
+        dict(name='train'),
+        dict(name='on_train_dataloader'),
+        dict(name='train_dataloader'),
+        dict(name='Callback.on_train_start', args=(trainer, model)),
+        dict(name='on_train_start'),
+        dict(name='Callback.on_epoch_start', args=(trainer, model)),
+        dict(name='on_epoch_start'),
+        dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
+        dict(name='on_train_epoch_start'),
+        *model._train_batch(trainer, model, train_batches),
+        dict(name='train', args=(False, )),
+        dict(name='on_validation_model_eval'),
+        dict(name='zero_grad'),
+        dict(name='Callback.on_validation_start', args=(trainer, model)),
+        dict(name='on_validation_start'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        dict(name='Callback.on_validation_end', args=(trainer, model)),
+        # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
+        dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
+        dict(name='on_save_checkpoint', args=(saved_ckpt, )),
+        dict(name='on_validation_end'),
+        dict(name='train'),
+        dict(name='on_validation_model_train'),
+        dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
+        dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
+        dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
+        dict(name='Callback.on_epoch_end', args=(trainer, model)),
+        dict(name='on_epoch_end'),
+        dict(name='Callback.on_train_end', args=(trainer, model)),
+        dict(name='on_train_end'),
+        dict(name='Callback.on_fit_end', args=(trainer, model)),
+        dict(name='on_fit_end'),
+        dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='fit')),
+        dict(name='teardown', kwargs=dict(stage='fit')),
     ]
-    called = [c['name'] for c in called]
     assert called == expected

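
`HookedCallback` is introduced by this PR, but its definition is not part of the hunks shown here. A hypothetical sketch of the recording pattern the `called` list relies on (names and structure are illustrative, not the real test helper):

from functools import partial
from unittest.mock import ANY

class CallRecorder:
    # Each listed hook is replaced by a wrapper that appends a
    # dict(name=..., args=..., kwargs=...) entry to the shared list.
    def __init__(self, called, hooks):
        self._called = called
        for hook in hooks:
            setattr(self, hook, partial(self._record, hook))

    def _record(self, name, *args, **kwargs):
        entry = {'name': name}
        if args:
            entry['args'] = args
        if kwargs:
            entry['kwargs'] = kwargs
        self._called.append(entry)

called = []
recorder = CallRecorder(called, hooks=['setup', 'on_fit_start'])
recorder.setup('trainer', 'model', stage='fit')
recorder.on_fit_start('trainer', 'model')
assert called == [
    dict(name='setup', args=(ANY, ANY), kwargs=dict(stage='fit')),
    dict(name='on_fit_start', args=('trainer', 'model')),
]
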
@@ -488,7 +518,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         'on_train_start',
         'on_epoch_start',
         'on_train_epoch_start',
-        *(HookedModel._train_batch() * train_batches),
+        *[
+            h['name']
+            for h in HookedModel._train_batch(trainer, model, train_batches) if not h['name'].startswith('Callback')
+        ],
         'training_epoch_end',
         'on_train_epoch_end',
         'on_epoch_end',