diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 288f6b0f8c..1f17308df7 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -254,7 +254,7 @@ class TrainerCallbackHookMixin(ABC):
     @staticmethod
     def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
         parameters = list(signature(fn).parameters)
-        return len(parameters) == 2 and parameters[1] != "args"
+        return len(parameters) == 2 and parameters[0] != "args"
 
     @staticmethod
     def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 0cbd663e07..5e08a82e4b 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -366,7 +366,7 @@ class ResultCollection(dict):
         return self.get('_extra', {})
 
     @extra.setter
-    def extra(self, extra: Mapping[str, Any]) -> None:
+    def extra(self, extra: Dict[str, Any]) -> None:
 
         def check_fn(v):
             if v.grad_fn is not None:
@@ -378,7 +378,8 @@ class ResultCollection(dict):
                 return v.detach()
             return v
 
-        extra = apply_to_collection(extra, torch.Tensor, check_fn)
+        # update instead of replace to keep the extra dict reference. TODO: remove with v1.6 deprecation removal
+        extra.update(apply_to_collection(extra, torch.Tensor, check_fn))
         self['_extra'] = extra
 
     def log(
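The `parameters[1]` to `parameters[0]` fix above appears to target `*args`-style overrides: `inspect.signature` reports a `MagicMock` (and any `def hook(self, *args, **kwargs)` override) as `(*args, **kwargs)`, whose first parameter is named `args`. A minimal sketch with hypothetical callback classes — only the one-line predicate is taken from the patch:

```python
from inspect import signature
from unittest.mock import MagicMock


def is_old_signature(fn) -> bool:
    # same predicate as the patched __is_old_signature_on_save_checkpoint
    parameters = list(signature(fn).parameters)
    return len(parameters) == 2 and parameters[0] != "args"


class OldCallback:  # hypothetical: deprecated two-argument signature
    def on_save_checkpoint(self, trainer, pl_module):
        ...


class NewCallback:  # hypothetical: current three-argument signature
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        ...


assert is_old_signature(OldCallback().on_save_checkpoint)
assert not is_old_signature(NewCallback().on_save_checkpoint)
# A MagicMock's signature is (*args, **kwargs), i.e. ['args', 'kwargs'];
# checking parameters[1] ('kwargs') != 'args' misclassified it as old-style.
assert not is_old_signature(MagicMock())
```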
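The `extra.update(...)` change swaps rebinding for in-place mutation, so a dict reference handed out earlier (the deprecated access path the TODO refers to) also sees the detached tensors. A small sketch of the difference, with `check_fn` simplified from the setter above:

```python
import torch


def check_fn(v: torch.Tensor) -> torch.Tensor:
    # simplified stand-in for the detach logic in the `extra` setter
    return v.detach() if v.grad_fn is not None else v


extra = {'t': torch.ones(1, requires_grad=True) * 2}  # tensor with a grad_fn
alias = extra  # a reference a caller may still hold

# Rebinding (`extra = {...}`) would leave `alias` pointing at the old,
# non-detached tensor. Updating in place changes what both names see:
extra.update({k: check_fn(v) for k, v in extra.items()})
assert alias['t'].grad_fn is None
```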
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index ed6962a0d4..57fdd1bf66 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -11,95 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest import mock
-from unittest.mock import ANY, call, MagicMock, Mock
+from unittest.mock import call, Mock
 
 from pytorch_lightning import Trainer
 from tests.helpers import BoringModel
 
 
-@mock.patch("torch.save")  # need to mock torch.save or we get pickle error
-def test_trainer_callback_hook_system_fit(_, tmpdir):
-    """Test the callback hook system for fit."""
-
-    model = BoringModel()
-    callback_mock = MagicMock()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=[callback_mock],
-        max_epochs=1,
-        limit_val_batches=1,
-        limit_train_batches=3,
-        progress_bar_refresh_rate=0,
-    )
-
-    # check that only the two calls exist
-    assert trainer.callbacks[0] == callback_mock
-    assert callback_mock.method_calls == [
-        call.on_init_start(trainer),
-        call.on_init_end(trainer),
-    ]
-
-    # fit model
-    trainer.fit(model)
-
-    assert callback_mock.method_calls == [
-        call.on_init_start(trainer),
-        call.on_init_end(trainer),
-        call.on_before_accelerator_backend_setup(trainer, model),
-        call.setup(trainer, model, stage='fit'),
-        call.on_configure_sharded_model(trainer, model),
-        call.on_fit_start(trainer, model),
-        call.on_pretrain_routine_start(trainer, model),
-        call.on_pretrain_routine_end(trainer, model),
-        call.on_sanity_check_start(trainer, model),
-        call.on_validation_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_validation_epoch_start(trainer, model),
-        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
-        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_validation_epoch_end(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_validation_end(trainer, model),
-        call.on_sanity_check_end(trainer, model),
-        call.on_train_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_train_epoch_start(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 0, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_batch_end(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 1, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
-        call.on_batch_end(trainer, model),
-        call.on_batch_start(trainer, model),
-        call.on_train_batch_start(trainer, model, ANY, 2, 0),
-        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
-        call.on_after_backward(trainer, model),
-        call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
-        call.on_batch_end(trainer, model),
-        call.on_validation_start(trainer, model),
-        call.on_epoch_start(trainer, model),
-        call.on_validation_epoch_start(trainer, model),
-        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
-        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_validation_epoch_end(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_validation_end(trainer, model),
-        call.on_save_checkpoint(trainer, model),  # should take ANY but we are inspecting signature for BC
-        call.on_train_epoch_end(trainer, model, ANY),
-        call.on_epoch_end(trainer, model),
-        call.on_train_end(trainer, model),
-        call.on_fit_end(trainer, model),
-        call.teardown(trainer, model, stage='fit'),
-    ]
-
-
 def test_callbacks_configured_in_model(tmpdir):
     """
     Test the callback system with callbacks added through the model hook.
     """
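The deleted `MagicMock`-based test is superseded by the stricter hook tests in `tests/models/test_hooks.py` below, where a `HookedCallback` appends into the same `called` list as the `HookedModel`. The `HookedCallback` definition is not part of this diff; a minimal sketch of the recording pattern it relies on might look like:

```python
from pytorch_lightning import Callback


class RecordingCallback(Callback):
    """Hypothetical, minimal stand-in for the test's HookedCallback."""

    def __init__(self, called: list):
        self.called = called  # shared with the HookedModel under test

    def on_init_start(self, trainer):
        self.called.append(dict(name='Callback.on_init_start', args=(trainer, )))

    def on_init_end(self, trainer):
        self.called.append(dict(name='Callback.on_init_end', args=(trainer, )))

    # ...one override per hook, each appending its qualified name and args
```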
""" diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index d724a15504..37e4867c7b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -283,35 +283,38 @@ class HookedModel(BoringModel): pass @staticmethod - def _train_batch(): - return [ - 'on_train_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'forward', - 'training_step', - 'training_step_end', - 'on_before_zero_grad', - 'optimizer_zero_grad', - 'backward', - 'on_after_backward', - 'optimizer_step', - 'on_train_batch_end', - ] - - @staticmethod - def _val_batch(): - return [ - 'on_validation_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'forward', - 'validation_step', - 'validation_step_end', - 'on_validation_batch_end', - ] + def _train_batch(trainer, model, batches): + out = [] + for i in range(batches): + out.extend([ + # TODO: `on_batch_{start,end}` + dict(name='Callback.on_batch_start', args=(trainer, model)), + dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)), + dict(name='on_train_batch_start', args=(ANY, i, 0)), + dict(name='on_before_batch_transfer', args=(ANY, None)), + dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)), + dict(name='on_after_batch_transfer', args=(ANY, None)), + dict(name='forward', args=(ANY, )), + dict(name='training_step', args=(ANY, i)), + dict(name='training_step_end', args=(dict(loss=ANY), )), + dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)), + dict(name='on_before_zero_grad', args=(ANY, )), + dict(name='optimizer_zero_grad', args=(0, i, ANY, 0)), + # TODO: `on_before_backward` + dict(name='backward', args=(ANY, ANY, 0)), + dict(name='Callback.on_after_backward', args=(trainer, model)), + dict(name='on_after_backward'), + # TODO: `on_before_optimizer_step` + dict( + name='optimizer_step', + args=(0, i, ANY, 0, ANY), + kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False) + ), + dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)), + dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)), + dict(name='Callback.on_batch_end', args=(trainer, model)), + ]) + return out @staticmethod def _eval_epoch(fn, trainer, model, batches, key): @@ -372,6 +375,7 @@ class HookedModel(BoringModel): def test_trainer_model_hook_system_fit(tmpdir): called = [] model = HookedModel(called) + callback = HookedCallback(called) train_batches = 2 val_batches = 2 trainer = Trainer( @@ -381,63 +385,89 @@ def test_trainer_model_hook_system_fit(tmpdir): limit_val_batches=val_batches, progress_bar_refresh_rate=0, weights_summary=None, + callbacks=[callback] ) - assert called == [] - trainer.fit(model) - expected = [ - 'prepare_data', - 'configure_callbacks', - 'setup', - 'configure_sharded_model', - 'configure_optimizers', - 'on_fit_start', - 'on_pretrain_routine_start', - 'on_pretrain_routine_end', - 'on_val_dataloader', - 'val_dataloader', - 'train', # eval() == train(False) - 'on_validation_model_eval', - 'zero_grad', - 'on_validation_start', - 'on_epoch_start', - 'on_validation_epoch_start', - *(HookedModel._val_batch() * val_batches), - 'validation_epoch_end', - 'on_validation_epoch_end', - 'on_epoch_end', - 'on_validation_end', - 'train', - 'on_validation_model_train', - # duplicate `train` because `_run_train` calls it again in case validation wasn't run - 'train', - 'on_train_dataloader', - 'train_dataloader', - 'on_train_start', - 
-        'on_epoch_start',
-        'on_train_epoch_start',
-        *(HookedModel._train_batch() * train_batches),
-        'train',  # eval() == train(False)
-        'on_validation_model_eval',
-        'zero_grad',
-        'on_validation_start',
-        'on_epoch_start',
-        'on_validation_epoch_start',
-        *(HookedModel._val_batch() * val_batches),
-        'validation_epoch_end',
-        'on_validation_epoch_end',
-        'on_epoch_end',
-        'on_save_checkpoint',
-        'on_validation_end',
-        'train',
-        'on_validation_model_train',
-        'training_epoch_end',
-        'on_train_epoch_end',
-        'on_epoch_end',
-        'on_train_end',
-        'on_fit_end',
-        'teardown',
+    assert called == [
+        dict(name='Callback.on_init_start', args=(trainer, )),
+        dict(name='Callback.on_init_end', args=(trainer, )),
+    ]
+    trainer.fit(model)
+    saved_ckpt = {
+        'callbacks': ANY,
+        'epoch': 1,
+        'global_step': train_batches,
+        'lr_schedulers': ANY,
+        'optimizer_states': ANY,
+        'pytorch-lightning_version': __version__,
+        'state_dict': ANY,
+    }
+    expected = [
+        dict(name='Callback.on_init_start', args=(trainer, )),
+        dict(name='Callback.on_init_end', args=(trainer, )),
+        dict(name='prepare_data'),
+        dict(name='configure_callbacks'),
+        dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
+        dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
+        dict(name='setup', kwargs=dict(stage='fit')),
+        dict(name='configure_sharded_model'),
+        dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
+        dict(name='configure_optimizers'),
+        dict(name='Callback.on_fit_start', args=(trainer, model)),
+        dict(name='on_fit_start'),
+        dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
+        dict(name='on_pretrain_routine_start'),
+        dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
+        dict(name='on_pretrain_routine_end'),
+        dict(name='Callback.on_sanity_check_start', args=(trainer, model)),
+        dict(name='on_val_dataloader'),
+        dict(name='val_dataloader'),
+        dict(name='train', args=(False, )),
+        dict(name='on_validation_model_eval'),
+        dict(name='zero_grad'),
+        dict(name='Callback.on_validation_start', args=(trainer, model)),
+        dict(name='on_validation_start'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        dict(name='Callback.on_validation_end', args=(trainer, model)),
+        dict(name='on_validation_end'),
+        dict(name='train'),
+        dict(name='on_validation_model_train'),
+        dict(name='Callback.on_sanity_check_end', args=(trainer, model)),
+        # duplicate `train` because `_run_train` calls it again in case validation wasn't run
+        dict(name='train'),
+        dict(name='on_train_dataloader'),
+        dict(name='train_dataloader'),
+        dict(name='Callback.on_train_start', args=(trainer, model)),
+        dict(name='on_train_start'),
+        dict(name='Callback.on_epoch_start', args=(trainer, model)),
+        dict(name='on_epoch_start'),
+        dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
+        dict(name='on_train_epoch_start'),
+        *model._train_batch(trainer, model, train_batches),
+        dict(name='train', args=(False, )),
+        dict(name='on_validation_model_eval'),
+        dict(name='zero_grad'),
+        dict(name='Callback.on_validation_start', args=(trainer, model)),
+        dict(name='on_validation_start'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        dict(name='Callback.on_validation_end', args=(trainer, model)),
+        # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
+        dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
+        dict(name='on_save_checkpoint', args=(saved_ckpt, )),
+        dict(name='on_validation_end'),
+        dict(name='train'),
+        dict(name='on_validation_model_train'),
+        dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
+        dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
+        dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
+        dict(name='Callback.on_epoch_end', args=(trainer, model)),
+        dict(name='on_epoch_end'),
+        dict(name='Callback.on_train_end', args=(trainer, model)),
+        dict(name='on_train_end'),
+        dict(name='Callback.on_fit_end', args=(trainer, model)),
+        dict(name='on_fit_end'),
+        dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='fit')),
+        dict(name='teardown', kwargs=dict(stage='fit')),
     ]
-    called = [c['name'] for c in called]
     assert called == expected
@@ -488,7 +518,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         'on_train_start',
         'on_epoch_start',
         'on_train_epoch_start',
-        *(HookedModel._train_batch() * train_batches),
+        *[
+            h['name']
+            for h in HookedModel._train_batch(trainer, model, train_batches) if not h['name'].startswith('Callback')
+        ],
         'training_epoch_end',
         'on_train_epoch_end',
         'on_epoch_end',
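Throughout the new expectations, `unittest.mock.ANY` stands in for values that are impractical to pin down exactly (batch tensors, optimizers, checkpoint contents): it compares equal to anything, so the assertions verify hook order, names, and argument structure without exact values. For example:

```python
import torch
from unittest.mock import ANY

batch = torch.rand(2, 32)
recorded = dict(name='on_train_batch_start', args=(batch, 0, 0))
# ANY == x is True for every x, so only the call structure is asserted
assert recorded == dict(name='on_train_batch_start', args=(ANY, 0, 0))
```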