Add predict hook test (#7973)

This commit is contained in:
Carlos Mocholí 2021-06-16 15:09:24 +02:00 committed by GitHub
parent 917cf83638
commit 4ffba600c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 118 additions and 17 deletions

View File

@ -261,6 +261,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))
- Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973))
- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))

View File

@ -47,12 +47,12 @@ class TrainerCallbackHookMixin(ABC):
def setup(self, model: LightningModule, stage: Optional[str]) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.setup(self, model, stage)
callback.setup(self, model, stage=stage)
def teardown(self, stage: Optional[str] = None) -> None:
"""Called at the end of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.teardown(self, self.lightning_module, stage)
callback.teardown(self, self.lightning_module, stage=stage)
def on_init_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""

View File

@ -47,7 +47,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_init_start(trainer),
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'fit'),
call.setup(trainer, model, stage='fit'),
call.on_configure_sharded_model(trainer, model),
call.on_fit_start(trainer, model),
call.on_pretrain_routine_start(trainer, model),
@ -96,7 +96,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_epoch_end(trainer, model),
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, 'fit'),
call.teardown(trainer, model, stage='fit'),
]
@ -119,7 +119,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_init_start(trainer),
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'test'),
call.setup(trainer, model, stage='test'),
call.on_configure_sharded_model(trainer, model),
call.on_test_start(trainer, model),
call.on_epoch_start(trainer, model),
@ -131,7 +131,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_test_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_test_end(trainer, model),
call.teardown(trainer, model, 'test'),
call.teardown(trainer, model, stage='test'),
]
@ -154,7 +154,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_init_start(trainer),
call.on_init_end(trainer),
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'validate'),
call.setup(trainer, model, stage='validate'),
call.on_configure_sharded_model(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
@ -166,13 +166,10 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.teardown(trainer, model, 'validate'),
call.teardown(trainer, model, stage='validate'),
]
# TODO: add callback tests for predict and tune
def test_callbacks_configured_in_model(tmpdir):
""" Test the callback system with callbacks added through the model hook. """

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, LambdaCallback
@ -28,9 +29,13 @@ def test_lambda_call(tmpdir):
raise KeyboardInterrupt
checker = set()
hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
hooks_args = {h: (lambda x: lambda *_: checker.add(x))(h) for h in hooks}
hooks_args["on_save_checkpoint"] = (lambda x: lambda *_: [checker.add(x)])("on_save_checkpoint")
def call(hook, *_, **__):
checker.add(hook)
hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)}
hooks_args = {h: partial(call, h) for h in hooks}
hooks_args["on_save_checkpoint"] = lambda *_: [checker.add('on_save_checkpoint')]
model = CustomModel()
@ -59,4 +64,4 @@ def test_lambda_call(tmpdir):
trainer.test(model)
trainer.predict(model)
assert checker == set(hooks)
assert checker == hooks

View File

@ -20,7 +20,7 @@ import pytest
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import __version__, LightningDataModule, LightningModule, Trainer
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@ -235,6 +235,22 @@ def get_members(cls):
return {h for h, _ in getmembers(cls, predicate=isfunction) if not h.startswith('_')}
class HookedCallback(Callback):
def __init__(self, called):
def call(hook, *args, **kwargs):
d = {'name': f'Callback.{hook}'}
if args:
d['args'] = args
if kwargs:
d['kwargs'] = kwargs
called.append(d)
for h in get_members(Callback):
setattr(self, h, partial(call, h))
class HookedModel(BoringModel):
def __init__(self, called):
@ -246,7 +262,12 @@ class HookedModel(BoringModel):
def call(hook, fn, *args, **kwargs):
out = fn(*args, **kwargs)
called.append(hook)
d = {'name': hook}
if args:
d['args'] = args
if kwargs:
d['kwargs'] = kwargs
called.append(d)
return out
for h in pl_module_hooks:
@ -305,6 +326,25 @@ class HookedModel(BoringModel):
'on_test_batch_end',
]
@staticmethod
def _predict_batch(trainer, model, batches):
out = []
for i in range(batches):
out.extend([
# TODO: `{,Callback}.on_batch_{start,end}`
dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_predict_batch_start', args=(ANY, i, 0)),
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='forward', args=(ANY, )),
dict(name='predict_step', args=(ANY, i)),
# TODO: `predict_step_end`
dict(name='Callback.on_predict_batch_end', args=(trainer, model, ANY, ANY, i, 0)),
dict(name='on_predict_batch_end', args=(ANY, ANY, i, 0)),
])
return out
def test_trainer_model_hook_system_fit(tmpdir):
called = []
@ -374,6 +414,7 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_fit_end',
'teardown',
]
called = [c['name'] for c in called]
assert called == expected
@ -418,6 +459,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
'on_fit_end',
'teardown',
]
called = [c['name'] for c in called]
assert called == expected
@ -459,6 +501,7 @@ def test_trainer_model_hook_system_validate(tmpdir, batches):
*(hooks if batches else []),
'teardown',
]
called = [c['name'] for c in called]
assert called == expected
@ -497,9 +540,62 @@ def test_trainer_model_hook_system_test(tmpdir):
'on_test_model_train',
'teardown',
]
called = [c['name'] for c in called]
assert called == expected
def test_trainer_model_hook_system_predict(tmpdir):
called = []
model = HookedModel(called)
callback = HookedCallback(called)
batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_predict_batches=batches,
progress_bar_refresh_rate=0,
callbacks=[callback],
)
assert called == [
dict(name='Callback.on_init_start', args=(trainer, )),
dict(name='Callback.on_init_end', args=(trainer, )),
]
trainer.predict(model)
expected = [
dict(name='Callback.on_init_start', args=(trainer, )),
dict(name='Callback.on_init_end', args=(trainer, )),
dict(name='prepare_data'),
dict(name='configure_callbacks'),
dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='predict')),
dict(name='setup', kwargs=dict(stage='predict')),
dict(name='configure_sharded_model'),
dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
dict(name='on_predict_dataloader'),
dict(name='predict_dataloader'),
dict(name='train', args=(False, )),
dict(name='on_predict_model_eval'),
dict(name='zero_grad'),
dict(name='Callback.on_predict_start', args=(trainer, model)),
dict(name='on_predict_start'),
# TODO: `{,Callback}.on_epoch_{start,end}`
dict(name='Callback.on_predict_epoch_start', args=(trainer, model)),
dict(name='on_predict_epoch_start'),
*model._predict_batch(trainer, model, batches),
# TODO: `predict_epoch_end`
dict(name='Callback.on_predict_epoch_end', args=(trainer, model, [[ANY] * batches])),
dict(name='on_predict_epoch_end', args=([[ANY] * batches], )),
dict(name='Callback.on_predict_end', args=(trainer, model)),
dict(name='on_predict_end'),
# TODO: `on_predict_model_train`
dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='predict')),
dict(name='teardown', kwargs=dict(stage='predict')),
]
assert called == expected
# TODO: add test for tune
def test_hooks_with_different_argument_names(tmpdir):
"""
Test that argument names can be anything in the hooks