Add predict hook test (#7973)
This commit is contained in:
parent
917cf83638
commit
4ffba600c9
|
@ -261,6 +261,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))
|
- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))
|
||||||
|
|
||||||
|
|
||||||
|
- Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973))
|
||||||
|
|
||||||
|
|
||||||
- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
|
- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -47,12 +47,12 @@ class TrainerCallbackHookMixin(ABC):
|
||||||
def setup(self, model: LightningModule, stage: Optional[str]) -> None:
|
def setup(self, model: LightningModule, stage: Optional[str]) -> None:
|
||||||
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
|
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
callback.setup(self, model, stage)
|
callback.setup(self, model, stage=stage)
|
||||||
|
|
||||||
def teardown(self, stage: Optional[str] = None) -> None:
|
def teardown(self, stage: Optional[str] = None) -> None:
|
||||||
"""Called at the end of fit (train + validate), validate, test, or predict, or tune."""
|
"""Called at the end of fit (train + validate), validate, test, or predict, or tune."""
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
callback.teardown(self, self.lightning_module, stage)
|
callback.teardown(self, self.lightning_module, stage=stage)
|
||||||
|
|
||||||
def on_init_start(self):
|
def on_init_start(self):
|
||||||
"""Called when the trainer initialization begins, model has not yet been set."""
|
"""Called when the trainer initialization begins, model has not yet been set."""
|
||||||
|
|
|
@ -47,7 +47,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
||||||
call.on_init_start(trainer),
|
call.on_init_start(trainer),
|
||||||
call.on_init_end(trainer),
|
call.on_init_end(trainer),
|
||||||
call.on_before_accelerator_backend_setup(trainer, model),
|
call.on_before_accelerator_backend_setup(trainer, model),
|
||||||
call.setup(trainer, model, 'fit'),
|
call.setup(trainer, model, stage='fit'),
|
||||||
call.on_configure_sharded_model(trainer, model),
|
call.on_configure_sharded_model(trainer, model),
|
||||||
call.on_fit_start(trainer, model),
|
call.on_fit_start(trainer, model),
|
||||||
call.on_pretrain_routine_start(trainer, model),
|
call.on_pretrain_routine_start(trainer, model),
|
||||||
|
@ -96,7 +96,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
||||||
call.on_epoch_end(trainer, model),
|
call.on_epoch_end(trainer, model),
|
||||||
call.on_train_end(trainer, model),
|
call.on_train_end(trainer, model),
|
||||||
call.on_fit_end(trainer, model),
|
call.on_fit_end(trainer, model),
|
||||||
call.teardown(trainer, model, 'fit'),
|
call.teardown(trainer, model, stage='fit'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -119,7 +119,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
|
||||||
call.on_init_start(trainer),
|
call.on_init_start(trainer),
|
||||||
call.on_init_end(trainer),
|
call.on_init_end(trainer),
|
||||||
call.on_before_accelerator_backend_setup(trainer, model),
|
call.on_before_accelerator_backend_setup(trainer, model),
|
||||||
call.setup(trainer, model, 'test'),
|
call.setup(trainer, model, stage='test'),
|
||||||
call.on_configure_sharded_model(trainer, model),
|
call.on_configure_sharded_model(trainer, model),
|
||||||
call.on_test_start(trainer, model),
|
call.on_test_start(trainer, model),
|
||||||
call.on_epoch_start(trainer, model),
|
call.on_epoch_start(trainer, model),
|
||||||
|
@ -131,7 +131,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
|
||||||
call.on_test_epoch_end(trainer, model),
|
call.on_test_epoch_end(trainer, model),
|
||||||
call.on_epoch_end(trainer, model),
|
call.on_epoch_end(trainer, model),
|
||||||
call.on_test_end(trainer, model),
|
call.on_test_end(trainer, model),
|
||||||
call.teardown(trainer, model, 'test'),
|
call.teardown(trainer, model, stage='test'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
|
||||||
call.on_init_start(trainer),
|
call.on_init_start(trainer),
|
||||||
call.on_init_end(trainer),
|
call.on_init_end(trainer),
|
||||||
call.on_before_accelerator_backend_setup(trainer, model),
|
call.on_before_accelerator_backend_setup(trainer, model),
|
||||||
call.setup(trainer, model, 'validate'),
|
call.setup(trainer, model, stage='validate'),
|
||||||
call.on_configure_sharded_model(trainer, model),
|
call.on_configure_sharded_model(trainer, model),
|
||||||
call.on_validation_start(trainer, model),
|
call.on_validation_start(trainer, model),
|
||||||
call.on_epoch_start(trainer, model),
|
call.on_epoch_start(trainer, model),
|
||||||
|
@ -166,13 +166,10 @@ def test_trainer_callback_hook_system_validate(tmpdir):
|
||||||
call.on_validation_epoch_end(trainer, model),
|
call.on_validation_epoch_end(trainer, model),
|
||||||
call.on_epoch_end(trainer, model),
|
call.on_epoch_end(trainer, model),
|
||||||
call.on_validation_end(trainer, model),
|
call.on_validation_end(trainer, model),
|
||||||
call.teardown(trainer, model, 'validate'),
|
call.teardown(trainer, model, stage='validate'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# TODO: add callback tests for predict and tune
|
|
||||||
|
|
||||||
|
|
||||||
def test_callbacks_configured_in_model(tmpdir):
|
def test_callbacks_configured_in_model(tmpdir):
|
||||||
""" Test the callback system with callbacks added through the model hook. """
|
""" Test the callback system with callbacks added through the model hook. """
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything, Trainer
|
from pytorch_lightning import seed_everything, Trainer
|
||||||
from pytorch_lightning.callbacks import Callback, LambdaCallback
|
from pytorch_lightning.callbacks import Callback, LambdaCallback
|
||||||
|
@ -28,9 +29,13 @@ def test_lambda_call(tmpdir):
|
||||||
raise KeyboardInterrupt
|
raise KeyboardInterrupt
|
||||||
|
|
||||||
checker = set()
|
checker = set()
|
||||||
hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
|
|
||||||
hooks_args = {h: (lambda x: lambda *_: checker.add(x))(h) for h in hooks}
|
def call(hook, *_, **__):
|
||||||
hooks_args["on_save_checkpoint"] = (lambda x: lambda *_: [checker.add(x)])("on_save_checkpoint")
|
checker.add(hook)
|
||||||
|
|
||||||
|
hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)}
|
||||||
|
hooks_args = {h: partial(call, h) for h in hooks}
|
||||||
|
hooks_args["on_save_checkpoint"] = lambda *_: [checker.add('on_save_checkpoint')]
|
||||||
|
|
||||||
model = CustomModel()
|
model = CustomModel()
|
||||||
|
|
||||||
|
@ -59,4 +64,4 @@ def test_lambda_call(tmpdir):
|
||||||
trainer.test(model)
|
trainer.test(model)
|
||||||
trainer.predict(model)
|
trainer.predict(model)
|
||||||
|
|
||||||
assert checker == set(hooks)
|
assert checker == hooks
|
||||||
|
|
|
@ -20,7 +20,7 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from pytorch_lightning import __version__, LightningDataModule, LightningModule, Trainer
|
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer
|
||||||
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
|
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
|
||||||
from tests.helpers.runif import RunIf
|
from tests.helpers.runif import RunIf
|
||||||
|
|
||||||
|
@ -235,6 +235,22 @@ def get_members(cls):
|
||||||
return {h for h, _ in getmembers(cls, predicate=isfunction) if not h.startswith('_')}
|
return {h for h, _ in getmembers(cls, predicate=isfunction) if not h.startswith('_')}
|
||||||
|
|
||||||
|
|
||||||
|
class HookedCallback(Callback):
|
||||||
|
|
||||||
|
def __init__(self, called):
|
||||||
|
|
||||||
|
def call(hook, *args, **kwargs):
|
||||||
|
d = {'name': f'Callback.{hook}'}
|
||||||
|
if args:
|
||||||
|
d['args'] = args
|
||||||
|
if kwargs:
|
||||||
|
d['kwargs'] = kwargs
|
||||||
|
called.append(d)
|
||||||
|
|
||||||
|
for h in get_members(Callback):
|
||||||
|
setattr(self, h, partial(call, h))
|
||||||
|
|
||||||
|
|
||||||
class HookedModel(BoringModel):
|
class HookedModel(BoringModel):
|
||||||
|
|
||||||
def __init__(self, called):
|
def __init__(self, called):
|
||||||
|
@ -246,7 +262,12 @@ class HookedModel(BoringModel):
|
||||||
|
|
||||||
def call(hook, fn, *args, **kwargs):
|
def call(hook, fn, *args, **kwargs):
|
||||||
out = fn(*args, **kwargs)
|
out = fn(*args, **kwargs)
|
||||||
called.append(hook)
|
d = {'name': hook}
|
||||||
|
if args:
|
||||||
|
d['args'] = args
|
||||||
|
if kwargs:
|
||||||
|
d['kwargs'] = kwargs
|
||||||
|
called.append(d)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
for h in pl_module_hooks:
|
for h in pl_module_hooks:
|
||||||
|
@ -305,6 +326,25 @@ class HookedModel(BoringModel):
|
||||||
'on_test_batch_end',
|
'on_test_batch_end',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _predict_batch(trainer, model, batches):
|
||||||
|
out = []
|
||||||
|
for i in range(batches):
|
||||||
|
out.extend([
|
||||||
|
# TODO: `{,Callback}.on_batch_{start,end}`
|
||||||
|
dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
|
||||||
|
dict(name='on_predict_batch_start', args=(ANY, i, 0)),
|
||||||
|
dict(name='on_before_batch_transfer', args=(ANY, None)),
|
||||||
|
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
|
||||||
|
dict(name='on_after_batch_transfer', args=(ANY, None)),
|
||||||
|
dict(name='forward', args=(ANY, )),
|
||||||
|
dict(name='predict_step', args=(ANY, i)),
|
||||||
|
# TODO: `predict_step_end`
|
||||||
|
dict(name='Callback.on_predict_batch_end', args=(trainer, model, ANY, ANY, i, 0)),
|
||||||
|
dict(name='on_predict_batch_end', args=(ANY, ANY, i, 0)),
|
||||||
|
])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def test_trainer_model_hook_system_fit(tmpdir):
|
def test_trainer_model_hook_system_fit(tmpdir):
|
||||||
called = []
|
called = []
|
||||||
|
@ -374,6 +414,7 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
||||||
'on_fit_end',
|
'on_fit_end',
|
||||||
'teardown',
|
'teardown',
|
||||||
]
|
]
|
||||||
|
called = [c['name'] for c in called]
|
||||||
assert called == expected
|
assert called == expected
|
||||||
|
|
||||||
|
|
||||||
|
@ -418,6 +459,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
|
||||||
'on_fit_end',
|
'on_fit_end',
|
||||||
'teardown',
|
'teardown',
|
||||||
]
|
]
|
||||||
|
called = [c['name'] for c in called]
|
||||||
assert called == expected
|
assert called == expected
|
||||||
|
|
||||||
|
|
||||||
|
@ -459,6 +501,7 @@ def test_trainer_model_hook_system_validate(tmpdir, batches):
|
||||||
*(hooks if batches else []),
|
*(hooks if batches else []),
|
||||||
'teardown',
|
'teardown',
|
||||||
]
|
]
|
||||||
|
called = [c['name'] for c in called]
|
||||||
assert called == expected
|
assert called == expected
|
||||||
|
|
||||||
|
|
||||||
|
@ -497,9 +540,62 @@ def test_trainer_model_hook_system_test(tmpdir):
|
||||||
'on_test_model_train',
|
'on_test_model_train',
|
||||||
'teardown',
|
'teardown',
|
||||||
]
|
]
|
||||||
|
called = [c['name'] for c in called]
|
||||||
assert called == expected
|
assert called == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_trainer_model_hook_system_predict(tmpdir):
|
||||||
|
called = []
|
||||||
|
model = HookedModel(called)
|
||||||
|
callback = HookedCallback(called)
|
||||||
|
batches = 2
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_predict_batches=batches,
|
||||||
|
progress_bar_refresh_rate=0,
|
||||||
|
callbacks=[callback],
|
||||||
|
)
|
||||||
|
assert called == [
|
||||||
|
dict(name='Callback.on_init_start', args=(trainer, )),
|
||||||
|
dict(name='Callback.on_init_end', args=(trainer, )),
|
||||||
|
]
|
||||||
|
trainer.predict(model)
|
||||||
|
expected = [
|
||||||
|
dict(name='Callback.on_init_start', args=(trainer, )),
|
||||||
|
dict(name='Callback.on_init_end', args=(trainer, )),
|
||||||
|
dict(name='prepare_data'),
|
||||||
|
dict(name='configure_callbacks'),
|
||||||
|
dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
|
||||||
|
dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='predict')),
|
||||||
|
dict(name='setup', kwargs=dict(stage='predict')),
|
||||||
|
dict(name='configure_sharded_model'),
|
||||||
|
dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
|
||||||
|
dict(name='on_predict_dataloader'),
|
||||||
|
dict(name='predict_dataloader'),
|
||||||
|
dict(name='train', args=(False, )),
|
||||||
|
dict(name='on_predict_model_eval'),
|
||||||
|
dict(name='zero_grad'),
|
||||||
|
dict(name='Callback.on_predict_start', args=(trainer, model)),
|
||||||
|
dict(name='on_predict_start'),
|
||||||
|
# TODO: `{,Callback}.on_epoch_{start,end}`
|
||||||
|
dict(name='Callback.on_predict_epoch_start', args=(trainer, model)),
|
||||||
|
dict(name='on_predict_epoch_start'),
|
||||||
|
*model._predict_batch(trainer, model, batches),
|
||||||
|
# TODO: `predict_epoch_end`
|
||||||
|
dict(name='Callback.on_predict_epoch_end', args=(trainer, model, [[ANY] * batches])),
|
||||||
|
dict(name='on_predict_epoch_end', args=([[ANY] * batches], )),
|
||||||
|
dict(name='Callback.on_predict_end', args=(trainer, model)),
|
||||||
|
dict(name='on_predict_end'),
|
||||||
|
# TODO: `on_predict_model_train`
|
||||||
|
dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='predict')),
|
||||||
|
dict(name='teardown', kwargs=dict(stage='predict')),
|
||||||
|
]
|
||||||
|
assert called == expected
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add test for tune
|
||||||
|
|
||||||
|
|
||||||
def test_hooks_with_different_argument_names(tmpdir):
|
def test_hooks_with_different_argument_names(tmpdir):
|
||||||
"""
|
"""
|
||||||
Test that argument names can be anything in the hooks
|
Test that argument names can be anything in the hooks
|
||||||
|
|
Loading…
Reference in New Issue