Improve `LightningModule` hook tests (#7944)
This commit is contained in:
parent
3a0ed02bd4
commit
03e7bdf8d5
|
@ -105,7 +105,7 @@ class MNIST(Dataset):
|
|||
raise RuntimeError('Dataset not found.')
|
||||
|
||||
def _download(self, data_folder: str) -> None:
|
||||
os.makedirs(data_folder)
|
||||
os.makedirs(data_folder, exist_ok=True)
|
||||
for url in self.RESOURCES:
|
||||
logging.info(f'Downloading {url}')
|
||||
fpath = os.path.join(data_folder, os.path.basename(url))
|
||||
|
|
|
@ -20,7 +20,7 @@ import pytest
|
|||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import __version__, LightningDataModule, Trainer
|
||||
from pytorch_lightning import __version__, LightningDataModule, LightningModule, Trainer
|
||||
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
|
@ -231,17 +231,46 @@ def test_transfer_batch_hook_ddp(tmpdir):
|
|||
trainer.fit(model)
|
||||
|
||||
|
||||
def get_members(cls):
|
||||
return {h for h, _ in getmembers(cls, predicate=isfunction) if not h.startswith('_')}
|
||||
|
||||
|
||||
class HookedModel(BoringModel):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, called):
|
||||
super().__init__()
|
||||
self.called = []
|
||||
self.train_batch = [
|
||||
pl_module_hooks = get_members(LightningModule)
|
||||
# remove most `nn.Module` hooks
|
||||
module_hooks = get_members(torch.nn.Module)
|
||||
pl_module_hooks.difference_update(module_hooks - {'forward', 'zero_grad', 'train'})
|
||||
|
||||
def call(hook, fn, *args, **kwargs):
|
||||
out = fn(*args, **kwargs)
|
||||
called.append(hook)
|
||||
return out
|
||||
|
||||
for h in pl_module_hooks:
|
||||
attr = getattr(self, h)
|
||||
setattr(self, h, partial(call, h, attr))
|
||||
|
||||
def validation_epoch_end(self, *args, **kwargs):
|
||||
# `BoringModel` does not have a return for `validation_step_end` so this would fail
|
||||
pass
|
||||
|
||||
def test_epoch_end(self, *args, **kwargs):
|
||||
# `BoringModel` does not have a return for `test_step_end` so this would fail
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _train_batch():
|
||||
return [
|
||||
'on_train_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'forward',
|
||||
'training_step',
|
||||
'training_step_end',
|
||||
'on_before_zero_grad',
|
||||
'optimizer_zero_grad',
|
||||
'backward',
|
||||
|
@ -249,209 +278,37 @@ class HookedModel(BoringModel):
|
|||
'optimizer_step',
|
||||
'on_train_batch_end',
|
||||
]
|
||||
self.val_batch = [
|
||||
|
||||
@staticmethod
|
||||
def _val_batch():
|
||||
return [
|
||||
'on_validation_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'forward',
|
||||
'validation_step',
|
||||
'validation_step_end',
|
||||
'on_validation_batch_end',
|
||||
]
|
||||
|
||||
def prepare_data(self):
|
||||
self.called.append("prepare_data")
|
||||
return super().prepare_data()
|
||||
|
||||
def configure_callbacks(self):
|
||||
self.called.append("configure_callbacks")
|
||||
return super().configure_callbacks()
|
||||
|
||||
def configure_optimizers(self):
|
||||
self.called.append("configure_optimizers")
|
||||
return super().configure_optimizers()
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
self.called.append("training_step")
|
||||
return super().training_step(*args, **kwargs)
|
||||
|
||||
def optimizer_zero_grad(self, *args, **kwargs):
|
||||
self.called.append("optimizer_zero_grad")
|
||||
super().optimizer_zero_grad(*args, **kwargs)
|
||||
|
||||
def training_epoch_end(self, *args, **kwargs):
|
||||
self.called.append("training_epoch_end")
|
||||
super().training_epoch_end(*args, **kwargs)
|
||||
|
||||
def backward(self, *args, **kwargs):
|
||||
self.called.append("backward")
|
||||
super().backward(*args, **kwargs)
|
||||
|
||||
def on_after_backward(self):
|
||||
self.called.append("on_after_backward")
|
||||
super().on_after_backward()
|
||||
|
||||
def optimizer_step(self, *args, **kwargs):
|
||||
super().optimizer_step(*args, **kwargs)
|
||||
self.called.append("optimizer_step") # append after as closure calls other methods
|
||||
|
||||
def validation_epoch_end(self, *args, **kwargs):
|
||||
self.called.append("validation_epoch_end")
|
||||
super().validation_epoch_end(*args, **kwargs)
|
||||
|
||||
def on_before_zero_grad(self, *args, **kwargs):
|
||||
self.called.append("on_before_zero_grad")
|
||||
super().on_before_zero_grad(*args, **kwargs)
|
||||
|
||||
def on_epoch_start(self):
|
||||
self.called.append("on_epoch_start")
|
||||
super().on_epoch_start()
|
||||
|
||||
def on_epoch_end(self):
|
||||
self.called.append("on_epoch_end")
|
||||
super().on_epoch_end()
|
||||
|
||||
def on_fit_start(self):
|
||||
self.called.append("on_fit_start")
|
||||
super().on_fit_start()
|
||||
|
||||
def on_fit_end(self):
|
||||
self.called.append("on_fit_end")
|
||||
super().on_fit_end()
|
||||
|
||||
def on_hpc_load(self, *args, **kwargs):
|
||||
self.called.append("on_hpc_load")
|
||||
super().on_hpc_load(*args, **kwargs)
|
||||
|
||||
def on_hpc_save(self, *args, **kwargs):
|
||||
self.called.append("on_hpc_save")
|
||||
super().on_hpc_save(*args, **kwargs)
|
||||
|
||||
def on_load_checkpoint(self, *args, **kwargs):
|
||||
self.called.append("on_load_checkpoint")
|
||||
super().on_load_checkpoint(*args, **kwargs)
|
||||
|
||||
def on_save_checkpoint(self, *args, **kwargs):
|
||||
self.called.append("on_save_checkpoint")
|
||||
super().on_save_checkpoint(*args, **kwargs)
|
||||
|
||||
def on_pretrain_routine_start(self):
|
||||
self.called.append("on_pretrain_routine_start")
|
||||
super().on_pretrain_routine_start()
|
||||
|
||||
def on_pretrain_routine_end(self):
|
||||
self.called.append("on_pretrain_routine_end")
|
||||
super().on_pretrain_routine_end()
|
||||
|
||||
def on_train_start(self):
|
||||
self.called.append("on_train_start")
|
||||
super().on_train_start()
|
||||
|
||||
def on_train_end(self):
|
||||
self.called.append("on_train_end")
|
||||
super().on_train_end()
|
||||
|
||||
def on_before_batch_transfer(self, *args, **kwargs):
|
||||
self.called.append("on_before_batch_transfer")
|
||||
return super().on_before_batch_transfer(*args, **kwargs)
|
||||
|
||||
def transfer_batch_to_device(self, *args, **kwargs):
|
||||
self.called.append("transfer_batch_to_device")
|
||||
return super().transfer_batch_to_device(*args, **kwargs)
|
||||
|
||||
def on_after_batch_transfer(self, *args, **kwargs):
|
||||
self.called.append("on_after_batch_transfer")
|
||||
return super().on_after_batch_transfer(*args, **kwargs)
|
||||
|
||||
def on_train_batch_start(self, *args, **kwargs):
|
||||
self.called.append("on_train_batch_start")
|
||||
super().on_train_batch_start(*args, **kwargs)
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
self.called.append("on_train_batch_end")
|
||||
super().on_train_batch_end(*args, **kwargs)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
self.called.append("on_train_epoch_start")
|
||||
super().on_train_epoch_start()
|
||||
|
||||
def on_train_epoch_end(self):
|
||||
self.called.append("on_train_epoch_end")
|
||||
super().on_train_epoch_end()
|
||||
|
||||
def on_validation_start(self):
|
||||
self.called.append("on_validation_start")
|
||||
super().on_validation_start()
|
||||
|
||||
def on_validation_end(self):
|
||||
self.called.append("on_validation_end")
|
||||
super().on_validation_end()
|
||||
|
||||
def on_validation_batch_start(self, *args, **kwargs):
|
||||
self.called.append("on_validation_batch_start")
|
||||
super().on_validation_batch_start(*args, **kwargs)
|
||||
|
||||
def on_validation_batch_end(self, *args, **kwargs):
|
||||
self.called.append("on_validation_batch_end")
|
||||
super().on_validation_batch_end(*args, **kwargs)
|
||||
|
||||
def on_validation_epoch_start(self):
|
||||
self.called.append("on_validation_epoch_start")
|
||||
super().on_validation_epoch_start()
|
||||
|
||||
def on_validation_epoch_end(self, *args, **kwargs):
|
||||
self.called.append("on_validation_epoch_end")
|
||||
super().on_validation_epoch_end(*args, **kwargs)
|
||||
|
||||
def on_test_start(self):
|
||||
self.called.append("on_test_start")
|
||||
super().on_test_start()
|
||||
|
||||
def on_test_batch_start(self, *args, **kwargs):
|
||||
self.called.append("on_test_batch_start")
|
||||
super().on_test_batch_start(*args, **kwargs)
|
||||
|
||||
def on_test_batch_end(self, *args, **kwargs):
|
||||
self.called.append("on_test_batch_end")
|
||||
super().on_test_batch_end(*args, **kwargs)
|
||||
|
||||
def on_test_epoch_start(self):
|
||||
self.called.append("on_test_epoch_start")
|
||||
super().on_test_epoch_start()
|
||||
|
||||
def on_test_epoch_end(self, *args, **kwargs):
|
||||
self.called.append("on_test_epoch_end")
|
||||
super().on_test_epoch_end(*args, **kwargs)
|
||||
|
||||
def on_validation_model_eval(self):
|
||||
self.called.append("on_validation_model_eval")
|
||||
super().on_validation_model_eval()
|
||||
|
||||
def on_validation_model_train(self):
|
||||
self.called.append("on_validation_model_train")
|
||||
super().on_validation_model_train()
|
||||
|
||||
def on_test_model_eval(self):
|
||||
self.called.append("on_test_model_eval")
|
||||
super().on_test_model_eval()
|
||||
|
||||
def on_test_model_train(self):
|
||||
self.called.append("on_test_model_train")
|
||||
super().on_test_model_train()
|
||||
|
||||
def on_test_end(self):
|
||||
self.called.append("on_test_end")
|
||||
super().on_test_end()
|
||||
|
||||
def setup(self, stage=None):
|
||||
self.called.append(f"setup_{stage}")
|
||||
super().setup(stage=stage)
|
||||
|
||||
def teardown(self, stage=None):
|
||||
self.called.append(f"teardown_{stage}")
|
||||
super().teardown(stage)
|
||||
@staticmethod
|
||||
def _test_batch():
|
||||
return [
|
||||
'on_test_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'forward',
|
||||
'test_step',
|
||||
'test_step_end',
|
||||
'on_test_batch_end',
|
||||
]
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_fit(tmpdir):
|
||||
model = HookedModel()
|
||||
called = []
|
||||
model = HookedModel(called)
|
||||
train_batches = 2
|
||||
val_batches = 2
|
||||
trainer = Trainer(
|
||||
|
@ -462,53 +319,67 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
|||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
)
|
||||
assert model.called == []
|
||||
assert called == []
|
||||
trainer.fit(model)
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup_fit',
|
||||
'setup',
|
||||
'configure_sharded_model',
|
||||
'configure_optimizers',
|
||||
'on_fit_start',
|
||||
'on_pretrain_routine_start',
|
||||
'on_pretrain_routine_end',
|
||||
'on_val_dataloader',
|
||||
'val_dataloader',
|
||||
'train', # eval() == train(False)
|
||||
'on_validation_model_eval',
|
||||
'zero_grad',
|
||||
'on_validation_start',
|
||||
'on_epoch_start',
|
||||
'on_validation_epoch_start',
|
||||
*(model.val_batch * val_batches),
|
||||
*(HookedModel._val_batch() * val_batches),
|
||||
'validation_epoch_end',
|
||||
'on_validation_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_validation_end',
|
||||
'train',
|
||||
'on_validation_model_train',
|
||||
# duplicate `train` because `_run_train` calls it again in case validation wasn't run
|
||||
'train',
|
||||
'on_train_dataloader',
|
||||
'train_dataloader',
|
||||
'on_train_start',
|
||||
'on_epoch_start',
|
||||
'on_train_epoch_start',
|
||||
*(model.train_batch * train_batches),
|
||||
*(HookedModel._train_batch() * train_batches),
|
||||
'train', # eval() == train(False)
|
||||
'on_validation_model_eval',
|
||||
'zero_grad',
|
||||
'on_validation_start',
|
||||
'on_epoch_start',
|
||||
'on_validation_epoch_start',
|
||||
*(model.val_batch * val_batches),
|
||||
*(HookedModel._val_batch() * val_batches),
|
||||
'validation_epoch_end',
|
||||
'on_validation_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_save_checkpoint',
|
||||
'on_validation_end',
|
||||
'train',
|
||||
'on_validation_model_train',
|
||||
'training_epoch_end',
|
||||
'on_train_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_train_end',
|
||||
'on_fit_end',
|
||||
'teardown_fit',
|
||||
'teardown',
|
||||
]
|
||||
assert model.called == expected
|
||||
assert called == expected
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_fit_no_val(tmpdir):
|
||||
model = HookedModel()
|
||||
called = []
|
||||
model = HookedModel(called)
|
||||
train_batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
@ -518,96 +389,115 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
|
|||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
)
|
||||
assert model.called == []
|
||||
assert called == []
|
||||
trainer.fit(model)
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup_fit',
|
||||
'setup',
|
||||
'configure_sharded_model',
|
||||
'configure_optimizers',
|
||||
'on_fit_start',
|
||||
'on_pretrain_routine_start',
|
||||
'on_pretrain_routine_end',
|
||||
'train',
|
||||
'on_train_dataloader',
|
||||
'train_dataloader',
|
||||
# even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
|
||||
'on_val_dataloader',
|
||||
'val_dataloader',
|
||||
'on_train_start',
|
||||
'on_epoch_start',
|
||||
'on_train_epoch_start',
|
||||
*(model.train_batch * train_batches),
|
||||
*(HookedModel._train_batch() * train_batches),
|
||||
'training_epoch_end',
|
||||
'on_train_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_save_checkpoint', # from train epoch end
|
||||
'on_train_end',
|
||||
'on_fit_end',
|
||||
'teardown_fit',
|
||||
'teardown',
|
||||
]
|
||||
assert model.called == expected
|
||||
assert called == expected
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_validate(tmpdir):
|
||||
model = HookedModel()
|
||||
@pytest.mark.parametrize('batches', (0, 2))
|
||||
def test_trainer_model_hook_system_validate(tmpdir, batches):
|
||||
called = []
|
||||
model = HookedModel(called)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_val_batches=1,
|
||||
limit_val_batches=batches,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
)
|
||||
assert model.called == []
|
||||
assert called == []
|
||||
trainer.validate(model, verbose=False)
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup_validate',
|
||||
hooks = [
|
||||
'train', # eval() == train(False)
|
||||
'on_validation_model_eval',
|
||||
'zero_grad',
|
||||
'on_validation_start',
|
||||
'on_epoch_start',
|
||||
'on_validation_epoch_start',
|
||||
'on_validation_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'on_validation_batch_end',
|
||||
*(HookedModel._val_batch() * batches),
|
||||
'validation_epoch_end',
|
||||
'on_validation_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_validation_end',
|
||||
'on_validation_model_train',
|
||||
'teardown_validate',
|
||||
'train',
|
||||
'on_validation_model_train'
|
||||
]
|
||||
assert model.called == expected
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup',
|
||||
'configure_sharded_model',
|
||||
'on_val_dataloader',
|
||||
'val_dataloader',
|
||||
*(hooks if batches else []),
|
||||
'teardown',
|
||||
]
|
||||
assert called == expected
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_test(tmpdir):
|
||||
model = HookedModel()
|
||||
called = []
|
||||
model = HookedModel(called)
|
||||
batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_test_batches=1,
|
||||
limit_test_batches=batches,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
)
|
||||
assert model.called == []
|
||||
assert called == []
|
||||
trainer.test(model, verbose=False)
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'configure_callbacks',
|
||||
'setup_test',
|
||||
'setup',
|
||||
'configure_sharded_model',
|
||||
'on_test_dataloader',
|
||||
'test_dataloader',
|
||||
'train', # eval() == train(False)
|
||||
'on_test_model_eval',
|
||||
'zero_grad',
|
||||
'on_test_start',
|
||||
'on_epoch_start',
|
||||
'on_test_epoch_start',
|
||||
'on_test_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'on_test_batch_end',
|
||||
*(HookedModel._test_batch() * batches),
|
||||
'test_epoch_end',
|
||||
'on_test_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_test_end',
|
||||
'train',
|
||||
'on_test_model_train',
|
||||
'teardown_test',
|
||||
'teardown',
|
||||
]
|
||||
assert model.called == expected
|
||||
assert called == expected
|
||||
|
||||
|
||||
def test_hooks_with_different_argument_names(tmpdir):
|
||||
|
@ -681,8 +571,7 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
called.append(d)
|
||||
return out
|
||||
|
||||
hooks = {h for h, _ in getmembers(LightningDataModule, predicate=isfunction)}
|
||||
for h in hooks:
|
||||
for h in get_members(LightningDataModule):
|
||||
attr = getattr(self, h)
|
||||
setattr(self, h, partial(call, h, attr))
|
||||
|
||||
|
|
Loading…
Reference in New Issue