diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0fb3acd19b..bbdd6ca104 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -261,6 +261,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))


+- Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973))
+
+
 - Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 23df26b410..3b5b4d4038 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -47,12 +47,12 @@ class TrainerCallbackHookMixin(ABC):
     def setup(self, model: LightningModule, stage: Optional[str]) -> None:
         """Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
         for callback in self.callbacks:
-            callback.setup(self, model, stage)
+            callback.setup(self, model, stage=stage)

     def teardown(self, stage: Optional[str] = None) -> None:
         """Called at the end of fit (train + validate), validate, test, or predict, or tune."""
         for callback in self.callbacks:
-            callback.teardown(self, self.lightning_module, stage)
+            callback.teardown(self, self.lightning_module, stage=stage)

     def on_init_start(self):
         """Called when the trainer initialization begins, model has not yet been set."""
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index a22e72ce09..13ead5a69c 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -47,7 +47,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_init_start(trainer),
         call.on_init_end(trainer),
         call.on_before_accelerator_backend_setup(trainer, model),
-        call.setup(trainer, model, 'fit'),
+        call.setup(trainer, model, stage='fit'),
         call.on_configure_sharded_model(trainer, model),
         call.on_fit_start(trainer, model),
         call.on_pretrain_routine_start(trainer, model),
@@ -96,7 +96,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_epoch_end(trainer, model),
         call.on_train_end(trainer, model),
         call.on_fit_end(trainer, model),
-        call.teardown(trainer, model, 'fit'),
+        call.teardown(trainer, model, stage='fit'),
     ]


@@ -119,7 +119,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
         call.on_init_start(trainer),
         call.on_init_end(trainer),
         call.on_before_accelerator_backend_setup(trainer, model),
-        call.setup(trainer, model, 'test'),
+        call.setup(trainer, model, stage='test'),
         call.on_configure_sharded_model(trainer, model),
         call.on_test_start(trainer, model),
         call.on_epoch_start(trainer, model),
@@ -131,7 +131,7 @@
         call.on_test_epoch_end(trainer, model),
         call.on_epoch_end(trainer, model),
         call.on_test_end(trainer, model),
-        call.teardown(trainer, model, 'test'),
+        call.teardown(trainer, model, stage='test'),
     ]


@@ -154,7 +154,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
         call.on_init_start(trainer),
         call.on_init_end(trainer),
         call.on_before_accelerator_backend_setup(trainer, model),
-        call.setup(trainer, model, 'validate'),
+        call.setup(trainer, model, stage='validate'),
         call.on_configure_sharded_model(trainer, model),
         call.on_validation_start(trainer, model),
         call.on_epoch_start(trainer, model),
@@ -166,13 +166,10 @@
         call.on_validation_epoch_end(trainer, model),
         call.on_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
-        call.teardown(trainer, model, 'validate'),
+        call.teardown(trainer, model, stage='validate'),
     ]


-# TODO: add callback tests for predict and tune
-
-
 def test_callbacks_configured_in_model(tmpdir):
     """ Test the callback system with callbacks added through the model hook. """
diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py
index 8d9f85fa56..845846dfd1 100644
--- a/tests/callbacks/test_lambda_function.py
+++ b/tests/callbacks/test_lambda_function.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from functools import partial

 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import Callback, LambdaCallback
@@ -28,9 +29,13 @@ def test_lambda_call(tmpdir):
                 raise KeyboardInterrupt

     checker = set()
-    hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
-    hooks_args = {h: (lambda x: lambda *_: checker.add(x))(h) for h in hooks}
-    hooks_args["on_save_checkpoint"] = (lambda x: lambda *_: [checker.add(x)])("on_save_checkpoint")
+
+    def call(hook, *_, **__):
+        checker.add(hook)
+
+    hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)}
+    hooks_args = {h: partial(call, h) for h in hooks}
+    hooks_args["on_save_checkpoint"] = lambda *_: [checker.add('on_save_checkpoint')]

     model = CustomModel()

@@ -59,4 +64,4 @@ def test_lambda_call(tmpdir):
     trainer.test(model)
     trainer.predict(model)

-    assert checker == set(hooks)
+    assert checker == hooks
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 6413ca8c93..7ab93e9ad2 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -20,7 +20,7 @@
 import pytest
 import torch
 from torch.utils.data import DataLoader

-from pytorch_lightning import __version__, LightningDataModule, LightningModule, Trainer
+from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer
 from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -235,6 +235,22 @@ def get_members(cls):
     return {h for h, _ in getmembers(cls, predicate=isfunction) if not h.startswith('_')}


+class HookedCallback(Callback):
+
+    def __init__(self, called):
+
+        def call(hook, *args, **kwargs):
+            d = {'name': f'Callback.{hook}'}
+            if args:
+                d['args'] = args
+            if kwargs:
+                d['kwargs'] = kwargs
+            called.append(d)
+
+        for h in get_members(Callback):
+            setattr(self, h, partial(call, h))
+
+
 class HookedModel(BoringModel):

     def __init__(self, called):
@@ -246,7 +262,12 @@ class HookedModel(BoringModel):

         def call(hook, fn, *args, **kwargs):
             out = fn(*args, **kwargs)
-            called.append(hook)
+            d = {'name': hook}
+            if args:
+                d['args'] = args
+            if kwargs:
+                d['kwargs'] = kwargs
+            called.append(d)
             return out

         for h in pl_module_hooks:
@@ -305,6 +326,25 @@ class HookedModel(BoringModel):
             'on_test_batch_end',
         ]

+    @staticmethod
+    def _predict_batch(trainer, model, batches):
+        out = []
+        for i in range(batches):
+            out.extend([
+                # TODO: `{,Callback}.on_batch_{start,end}`
+                dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
+                dict(name='on_predict_batch_start', args=(ANY, i, 0)),
+                dict(name='on_before_batch_transfer', args=(ANY, None)),
+                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
+                dict(name='on_after_batch_transfer', args=(ANY, None)),
+                dict(name='forward', args=(ANY, )),
+                dict(name='predict_step', args=(ANY, i)),
+                # TODO: `predict_step_end`
+                dict(name='Callback.on_predict_batch_end', args=(trainer, model, ANY, ANY, i, 0)),
+                dict(name='on_predict_batch_end', args=(ANY, ANY, i, 0)),
+            ])
+        return out
+

 def test_trainer_model_hook_system_fit(tmpdir):
     called = []
@@ -374,6 +414,7 @@ def test_trainer_model_hook_system_fit(tmpdir):
         'on_fit_end',
         'teardown',
     ]
+    called = [c['name'] for c in called]
     assert called == expected


@@ -418,6 +459,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
         'on_fit_end',
         'teardown',
     ]
+    called = [c['name'] for c in called]
     assert called == expected


@@ -459,6 +501,7 @@ def test_trainer_model_hook_system_validate(tmpdir, batches):
         *(hooks if batches else []),
         'teardown',
     ]
+    called = [c['name'] for c in called]
     assert called == expected


@@ -497,9 +540,62 @@ def test_trainer_model_hook_system_test(tmpdir):
         'on_test_model_train',
         'teardown',
     ]
+    called = [c['name'] for c in called]
     assert called == expected


+def test_trainer_model_hook_system_predict(tmpdir):
+    called = []
+    model = HookedModel(called)
+    callback = HookedCallback(called)
+    batches = 2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_predict_batches=batches,
+        progress_bar_refresh_rate=0,
+        callbacks=[callback],
+    )
+    assert called == [
+        dict(name='Callback.on_init_start', args=(trainer, )),
+        dict(name='Callback.on_init_end', args=(trainer, )),
+    ]
+    trainer.predict(model)
+    expected = [
+        dict(name='Callback.on_init_start', args=(trainer, )),
+        dict(name='Callback.on_init_end', args=(trainer, )),
+        dict(name='prepare_data'),
+        dict(name='configure_callbacks'),
+        dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
+        dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='predict')),
+        dict(name='setup', kwargs=dict(stage='predict')),
+        dict(name='configure_sharded_model'),
+        dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
+        dict(name='on_predict_dataloader'),
+        dict(name='predict_dataloader'),
+        dict(name='train', args=(False, )),
+        dict(name='on_predict_model_eval'),
+        dict(name='zero_grad'),
+        dict(name='Callback.on_predict_start', args=(trainer, model)),
+        dict(name='on_predict_start'),
+        # TODO: `{,Callback}.on_epoch_{start,end}`
+        dict(name='Callback.on_predict_epoch_start', args=(trainer, model)),
+        dict(name='on_predict_epoch_start'),
+        *model._predict_batch(trainer, model, batches),
+        # TODO: `predict_epoch_end`
+        dict(name='Callback.on_predict_epoch_end', args=(trainer, model, [[ANY] * batches])),
+        dict(name='on_predict_epoch_end', args=([[ANY] * batches], )),
+        dict(name='Callback.on_predict_end', args=(trainer, model)),
+        dict(name='on_predict_end'),
+        # TODO: `on_predict_model_train`
+        dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='predict')),
+        dict(name='teardown', kwargs=dict(stage='predict')),
+    ]
+    assert called == expected
+
+
+# TODO: add test for tune
+
+
 def test_hooks_with_different_argument_names(tmpdir):
     """
     Test that argument names can be anything in the hooks
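
Note (not part of the patch above): the practical effect of the `callback_hook.py` hunks is that `Callback.setup` and `Callback.teardown` now receive `stage` by keyword, so a user-defined callback can declare it as a keyword-only parameter. A minimal sketch under that assumption; `StageAwareCallback` is a hypothetical example, not code from this PR:

    import pytorch_lightning as pl

    class StageAwareCallback(pl.Callback):
        # `stage` can be keyword-only because the Trainer now calls
        # `callback.setup(trainer, pl_module, stage=stage)` rather than
        # passing `stage` positionally.
        def setup(self, trainer, pl_module, *, stage=None):
            print(f"setup: stage={stage}")

        def teardown(self, trainer, pl_module, *, stage=None):
            print(f"teardown: stage={stage}")

    # Usage: pl.Trainer(callbacks=[StageAwareCallback()])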