Add predict hook test (#7973)
This commit is contained in:
parent
917cf83638
commit
4ffba600c9
|
@ -261,6 +261,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))
|
||||
|
||||
|
||||
- Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973))
|
||||
|
||||
|
||||
- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
|
||||
|
||||
|
||||
|
|
|
@ -47,12 +47,12 @@ class TrainerCallbackHookMixin(ABC):
|
|||
def setup(self, model: LightningModule, stage: Optional[str]) -> None:
|
||||
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
|
||||
for callback in self.callbacks:
|
||||
callback.setup(self, model, stage)
|
||||
callback.setup(self, model, stage=stage)
|
||||
|
||||
def teardown(self, stage: Optional[str] = None) -> None:
|
||||
"""Called at the end of fit (train + validate), validate, test, or predict, or tune."""
|
||||
for callback in self.callbacks:
|
||||
callback.teardown(self, self.lightning_module, stage)
|
||||
callback.teardown(self, self.lightning_module, stage=stage)
|
||||
|
||||
def on_init_start(self):
|
||||
"""Called when the trainer initialization begins, model has not yet been set."""
|
||||
|
|
|
@ -47,7 +47,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
|||
call.on_init_start(trainer),
|
||||
call.on_init_end(trainer),
|
||||
call.on_before_accelerator_backend_setup(trainer, model),
|
||||
call.setup(trainer, model, 'fit'),
|
||||
call.setup(trainer, model, stage='fit'),
|
||||
call.on_configure_sharded_model(trainer, model),
|
||||
call.on_fit_start(trainer, model),
|
||||
call.on_pretrain_routine_start(trainer, model),
|
||||
|
@ -96,7 +96,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
|||
call.on_epoch_end(trainer, model),
|
||||
call.on_train_end(trainer, model),
|
||||
call.on_fit_end(trainer, model),
|
||||
call.teardown(trainer, model, 'fit'),
|
||||
call.teardown(trainer, model, stage='fit'),
|
||||
]
|
||||
|
||||
|
||||
|
@ -119,7 +119,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
|
|||
call.on_init_start(trainer),
|
||||
call.on_init_end(trainer),
|
||||
call.on_before_accelerator_backend_setup(trainer, model),
|
||||
call.setup(trainer, model, 'test'),
|
||||
call.setup(trainer, model, stage='test'),
|
||||
call.on_configure_sharded_model(trainer, model),
|
||||
call.on_test_start(trainer, model),
|
||||
call.on_epoch_start(trainer, model),
|
||||
|
@ -131,7 +131,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
|
|||
call.on_test_epoch_end(trainer, model),
|
||||
call.on_epoch_end(trainer, model),
|
||||
call.on_test_end(trainer, model),
|
||||
call.teardown(trainer, model, 'test'),
|
||||
call.teardown(trainer, model, stage='test'),
|
||||
]
|
||||
|
||||
|
||||
|
@ -154,7 +154,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
|
|||
call.on_init_start(trainer),
|
||||
call.on_init_end(trainer),
|
||||
call.on_before_accelerator_backend_setup(trainer, model),
|
||||
call.setup(trainer, model, 'validate'),
|
||||
call.setup(trainer, model, stage='validate'),
|
||||
call.on_configure_sharded_model(trainer, model),
|
||||
call.on_validation_start(trainer, model),
|
||||
call.on_epoch_start(trainer, model),
|
||||
|
@ -166,13 +166,10 @@ def test_trainer_callback_hook_system_validate(tmpdir):
|
|||
call.on_validation_epoch_end(trainer, model),
|
||||
call.on_epoch_end(trainer, model),
|
||||
call.on_validation_end(trainer, model),
|
||||
call.teardown(trainer, model, 'validate'),
|
||||
call.teardown(trainer, model, stage='validate'),
|
||||
]
|
||||
|
||||
|
||||
# TODO: add callback tests for predict and tune
|
||||
|
||||
|
||||
def test_callbacks_configured_in_model(tmpdir):
|
||||
""" Test the callback system with callbacks added through the model hook. """
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from functools import partial
|
||||
|
||||
from pytorch_lightning import seed_everything, Trainer
|
||||
from pytorch_lightning.callbacks import Callback, LambdaCallback
|
||||
|
@ -28,9 +29,13 @@ def test_lambda_call(tmpdir):
|
|||
raise KeyboardInterrupt
|
||||
|
||||
checker = set()
|
||||
hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
|
||||
hooks_args = {h: (lambda x: lambda *_: checker.add(x))(h) for h in hooks}
|
||||
hooks_args["on_save_checkpoint"] = (lambda x: lambda *_: [checker.add(x)])("on_save_checkpoint")
|
||||
|
||||
def call(hook, *_, **__):
|
||||
checker.add(hook)
|
||||
|
||||
hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)}
|
||||
hooks_args = {h: partial(call, h) for h in hooks}
|
||||
hooks_args["on_save_checkpoint"] = lambda *_: [checker.add('on_save_checkpoint')]
|
||||
|
||||
model = CustomModel()
|
||||
|
||||
|
@ -59,4 +64,4 @@ def test_lambda_call(tmpdir):
|
|||
trainer.test(model)
|
||||
trainer.predict(model)
|
||||
|
||||
assert checker == set(hooks)
|
||||
assert checker == hooks
|
||||
|
|
|
@ -20,7 +20,7 @@ import pytest
|
|||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import __version__, LightningDataModule, LightningModule, Trainer
|
||||
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer
|
||||
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
|
@ -235,6 +235,22 @@ def get_members(cls):
|
|||
return {h for h, _ in getmembers(cls, predicate=isfunction) if not h.startswith('_')}
|
||||
|
||||
|
||||
class HookedCallback(Callback):
|
||||
|
||||
def __init__(self, called):
|
||||
|
||||
def call(hook, *args, **kwargs):
|
||||
d = {'name': f'Callback.{hook}'}
|
||||
if args:
|
||||
d['args'] = args
|
||||
if kwargs:
|
||||
d['kwargs'] = kwargs
|
||||
called.append(d)
|
||||
|
||||
for h in get_members(Callback):
|
||||
setattr(self, h, partial(call, h))
|
||||
|
||||
|
||||
class HookedModel(BoringModel):
|
||||
|
||||
def __init__(self, called):
|
||||
|
@ -246,7 +262,12 @@ class HookedModel(BoringModel):
|
|||
|
||||
def call(hook, fn, *args, **kwargs):
|
||||
out = fn(*args, **kwargs)
|
||||
called.append(hook)
|
||||
d = {'name': hook}
|
||||
if args:
|
||||
d['args'] = args
|
||||
if kwargs:
|
||||
d['kwargs'] = kwargs
|
||||
called.append(d)
|
||||
return out
|
||||
|
||||
for h in pl_module_hooks:
|
||||
|
@ -305,6 +326,25 @@ class HookedModel(BoringModel):
|
|||
'on_test_batch_end',
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _predict_batch(trainer, model, batches):
|
||||
out = []
|
||||
for i in range(batches):
|
||||
out.extend([
|
||||
# TODO: `{,Callback}.on_batch_{start,end}`
|
||||
dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
|
||||
dict(name='on_predict_batch_start', args=(ANY, i, 0)),
|
||||
dict(name='on_before_batch_transfer', args=(ANY, None)),
|
||||
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
|
||||
dict(name='on_after_batch_transfer', args=(ANY, None)),
|
||||
dict(name='forward', args=(ANY, )),
|
||||
dict(name='predict_step', args=(ANY, i)),
|
||||
# TODO: `predict_step_end`
|
||||
dict(name='Callback.on_predict_batch_end', args=(trainer, model, ANY, ANY, i, 0)),
|
||||
dict(name='on_predict_batch_end', args=(ANY, ANY, i, 0)),
|
||||
])
|
||||
return out
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_fit(tmpdir):
|
||||
called = []
|
||||
|
@ -374,6 +414,7 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
|||
'on_fit_end',
|
||||
'teardown',
|
||||
]
|
||||
called = [c['name'] for c in called]
|
||||
assert called == expected
|
||||
|
||||
|
||||
|
@ -418,6 +459,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
|
|||
'on_fit_end',
|
||||
'teardown',
|
||||
]
|
||||
called = [c['name'] for c in called]
|
||||
assert called == expected
|
||||
|
||||
|
||||
|
@ -459,6 +501,7 @@ def test_trainer_model_hook_system_validate(tmpdir, batches):
|
|||
*(hooks if batches else []),
|
||||
'teardown',
|
||||
]
|
||||
called = [c['name'] for c in called]
|
||||
assert called == expected
|
||||
|
||||
|
||||
|
@ -497,9 +540,62 @@ def test_trainer_model_hook_system_test(tmpdir):
|
|||
'on_test_model_train',
|
||||
'teardown',
|
||||
]
|
||||
called = [c['name'] for c in called]
|
||||
assert called == expected
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_predict(tmpdir):
|
||||
called = []
|
||||
model = HookedModel(called)
|
||||
callback = HookedCallback(called)
|
||||
batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_predict_batches=batches,
|
||||
progress_bar_refresh_rate=0,
|
||||
callbacks=[callback],
|
||||
)
|
||||
assert called == [
|
||||
dict(name='Callback.on_init_start', args=(trainer, )),
|
||||
dict(name='Callback.on_init_end', args=(trainer, )),
|
||||
]
|
||||
trainer.predict(model)
|
||||
expected = [
|
||||
dict(name='Callback.on_init_start', args=(trainer, )),
|
||||
dict(name='Callback.on_init_end', args=(trainer, )),
|
||||
dict(name='prepare_data'),
|
||||
dict(name='configure_callbacks'),
|
||||
dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
|
||||
dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='predict')),
|
||||
dict(name='setup', kwargs=dict(stage='predict')),
|
||||
dict(name='configure_sharded_model'),
|
||||
dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
|
||||
dict(name='on_predict_dataloader'),
|
||||
dict(name='predict_dataloader'),
|
||||
dict(name='train', args=(False, )),
|
||||
dict(name='on_predict_model_eval'),
|
||||
dict(name='zero_grad'),
|
||||
dict(name='Callback.on_predict_start', args=(trainer, model)),
|
||||
dict(name='on_predict_start'),
|
||||
# TODO: `{,Callback}.on_epoch_{start,end}`
|
||||
dict(name='Callback.on_predict_epoch_start', args=(trainer, model)),
|
||||
dict(name='on_predict_epoch_start'),
|
||||
*model._predict_batch(trainer, model, batches),
|
||||
# TODO: `predict_epoch_end`
|
||||
dict(name='Callback.on_predict_epoch_end', args=(trainer, model, [[ANY] * batches])),
|
||||
dict(name='on_predict_epoch_end', args=([[ANY] * batches], )),
|
||||
dict(name='Callback.on_predict_end', args=(trainer, model)),
|
||||
dict(name='on_predict_end'),
|
||||
# TODO: `on_predict_model_train`
|
||||
dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='predict')),
|
||||
dict(name='teardown', kwargs=dict(stage='predict')),
|
||||
]
|
||||
assert called == expected
|
||||
|
||||
|
||||
# TODO: add test for tune
|
||||
|
||||
|
||||
def test_hooks_with_different_argument_names(tmpdir):
|
||||
"""
|
||||
Test that argument names can be anything in the hooks
|
||||
|
|
Loading…
Reference in New Issue