Improve `LightningModule` hook tests (#7944)

Carlos Mocholí 2021-06-14 18:16:42 +02:00 committed by GitHub
parent 3a0ed02bd4
commit 03e7bdf8d5
2 changed files with 129 additions and 240 deletions


@@ -105,7 +105,7 @@ class MNIST(Dataset):
raise RuntimeError('Dataset not found.')
def _download(self, data_folder: str) -> None:
os.makedirs(data_folder)
os.makedirs(data_folder, exist_ok=True)
for url in self.RESOURCES:
logging.info(f'Downloading {url}')
fpath = os.path.join(data_folder, os.path.basename(url))


@@ -20,7 +20,7 @@ import pytest
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import __version__, LightningDataModule, Trainer
from pytorch_lightning import __version__, LightningDataModule, LightningModule, Trainer
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@@ -231,17 +231,46 @@ def test_transfer_batch_hook_ddp(tmpdir):
trainer.fit(model)
def get_members(cls):
return {h for h, _ in getmembers(cls, predicate=isfunction) if not h.startswith('_')}
class HookedModel(BoringModel):
def __init__(self):
def __init__(self, called):
super().__init__()
self.called = []
self.train_batch = [
pl_module_hooks = get_members(LightningModule)
# remove most `nn.Module` hooks
module_hooks = get_members(torch.nn.Module)
pl_module_hooks.difference_update(module_hooks - {'forward', 'zero_grad', 'train'})
def call(hook, fn, *args, **kwargs):
out = fn(*args, **kwargs)
called.append(hook)
return out
for h in pl_module_hooks:
attr = getattr(self, h)
setattr(self, h, partial(call, h, attr))
def validation_epoch_end(self, *args, **kwargs):
# `BoringModel` does not have a return for `validation_step_end` so this would fail
pass
def test_epoch_end(self, *args, **kwargs):
# `BoringModel` does not have a return for `test_step_end` so this would fail
pass
@staticmethod
def _train_batch():
return [
'on_train_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'forward',
'training_step',
'training_step_end',
'on_before_zero_grad',
'optimizer_zero_grad',
'backward',
@@ -249,209 +278,37 @@ class HookedModel(BoringModel):
'optimizer_step',
'on_train_batch_end',
]
self.val_batch = [
@staticmethod
def _val_batch():
return [
'on_validation_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'forward',
'validation_step',
'validation_step_end',
'on_validation_batch_end',
]
def prepare_data(self):
self.called.append("prepare_data")
return super().prepare_data()
def configure_callbacks(self):
self.called.append("configure_callbacks")
return super().configure_callbacks()
def configure_optimizers(self):
self.called.append("configure_optimizers")
return super().configure_optimizers()
def training_step(self, *args, **kwargs):
self.called.append("training_step")
return super().training_step(*args, **kwargs)
def optimizer_zero_grad(self, *args, **kwargs):
self.called.append("optimizer_zero_grad")
super().optimizer_zero_grad(*args, **kwargs)
def training_epoch_end(self, *args, **kwargs):
self.called.append("training_epoch_end")
super().training_epoch_end(*args, **kwargs)
def backward(self, *args, **kwargs):
self.called.append("backward")
super().backward(*args, **kwargs)
def on_after_backward(self):
self.called.append("on_after_backward")
super().on_after_backward()
def optimizer_step(self, *args, **kwargs):
super().optimizer_step(*args, **kwargs)
self.called.append("optimizer_step") # append after as closure calls other methods
def validation_epoch_end(self, *args, **kwargs):
self.called.append("validation_epoch_end")
super().validation_epoch_end(*args, **kwargs)
def on_before_zero_grad(self, *args, **kwargs):
self.called.append("on_before_zero_grad")
super().on_before_zero_grad(*args, **kwargs)
def on_epoch_start(self):
self.called.append("on_epoch_start")
super().on_epoch_start()
def on_epoch_end(self):
self.called.append("on_epoch_end")
super().on_epoch_end()
def on_fit_start(self):
self.called.append("on_fit_start")
super().on_fit_start()
def on_fit_end(self):
self.called.append("on_fit_end")
super().on_fit_end()
def on_hpc_load(self, *args, **kwargs):
self.called.append("on_hpc_load")
super().on_hpc_load(*args, **kwargs)
def on_hpc_save(self, *args, **kwargs):
self.called.append("on_hpc_save")
super().on_hpc_save(*args, **kwargs)
def on_load_checkpoint(self, *args, **kwargs):
self.called.append("on_load_checkpoint")
super().on_load_checkpoint(*args, **kwargs)
def on_save_checkpoint(self, *args, **kwargs):
self.called.append("on_save_checkpoint")
super().on_save_checkpoint(*args, **kwargs)
def on_pretrain_routine_start(self):
self.called.append("on_pretrain_routine_start")
super().on_pretrain_routine_start()
def on_pretrain_routine_end(self):
self.called.append("on_pretrain_routine_end")
super().on_pretrain_routine_end()
def on_train_start(self):
self.called.append("on_train_start")
super().on_train_start()
def on_train_end(self):
self.called.append("on_train_end")
super().on_train_end()
def on_before_batch_transfer(self, *args, **kwargs):
self.called.append("on_before_batch_transfer")
return super().on_before_batch_transfer(*args, **kwargs)
def transfer_batch_to_device(self, *args, **kwargs):
self.called.append("transfer_batch_to_device")
return super().transfer_batch_to_device(*args, **kwargs)
def on_after_batch_transfer(self, *args, **kwargs):
self.called.append("on_after_batch_transfer")
return super().on_after_batch_transfer(*args, **kwargs)
def on_train_batch_start(self, *args, **kwargs):
self.called.append("on_train_batch_start")
super().on_train_batch_start(*args, **kwargs)
def on_train_batch_end(self, *args, **kwargs):
self.called.append("on_train_batch_end")
super().on_train_batch_end(*args, **kwargs)
def on_train_epoch_start(self):
self.called.append("on_train_epoch_start")
super().on_train_epoch_start()
def on_train_epoch_end(self):
self.called.append("on_train_epoch_end")
super().on_train_epoch_end()
def on_validation_start(self):
self.called.append("on_validation_start")
super().on_validation_start()
def on_validation_end(self):
self.called.append("on_validation_end")
super().on_validation_end()
def on_validation_batch_start(self, *args, **kwargs):
self.called.append("on_validation_batch_start")
super().on_validation_batch_start(*args, **kwargs)
def on_validation_batch_end(self, *args, **kwargs):
self.called.append("on_validation_batch_end")
super().on_validation_batch_end(*args, **kwargs)
def on_validation_epoch_start(self):
self.called.append("on_validation_epoch_start")
super().on_validation_epoch_start()
def on_validation_epoch_end(self, *args, **kwargs):
self.called.append("on_validation_epoch_end")
super().on_validation_epoch_end(*args, **kwargs)
def on_test_start(self):
self.called.append("on_test_start")
super().on_test_start()
def on_test_batch_start(self, *args, **kwargs):
self.called.append("on_test_batch_start")
super().on_test_batch_start(*args, **kwargs)
def on_test_batch_end(self, *args, **kwargs):
self.called.append("on_test_batch_end")
super().on_test_batch_end(*args, **kwargs)
def on_test_epoch_start(self):
self.called.append("on_test_epoch_start")
super().on_test_epoch_start()
def on_test_epoch_end(self, *args, **kwargs):
self.called.append("on_test_epoch_end")
super().on_test_epoch_end(*args, **kwargs)
def on_validation_model_eval(self):
self.called.append("on_validation_model_eval")
super().on_validation_model_eval()
def on_validation_model_train(self):
self.called.append("on_validation_model_train")
super().on_validation_model_train()
def on_test_model_eval(self):
self.called.append("on_test_model_eval")
super().on_test_model_eval()
def on_test_model_train(self):
self.called.append("on_test_model_train")
super().on_test_model_train()
def on_test_end(self):
self.called.append("on_test_end")
super().on_test_end()
def setup(self, stage=None):
self.called.append(f"setup_{stage}")
super().setup(stage=stage)
def teardown(self, stage=None):
self.called.append(f"teardown_{stage}")
super().teardown(stage)
@staticmethod
def _test_batch():
return [
'on_test_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'forward',
'test_step',
'test_step_end',
'on_test_batch_end',
]
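The `HookedModel.__init__` above replaces the hand-written overrides with dynamic wrapping: every public `LightningModule` method found via `get_members` is swapped on the instance for a wrapper that records the hook name once the original call returns. A minimal, self-contained sketch of that pattern (the `Greeter` class and its methods are illustrative only, not part of this commit):

from functools import partial
from inspect import getmembers, isfunction

class Greeter:
    def greet(self, name):
        return f'hello {name}'

    def farewell(self, name):
        return f'bye {name}'

called = []
obj = Greeter()

def call(hook, fn, *args, **kwargs):
    out = fn(*args, **kwargs)
    called.append(hook)  # append after the call, so nested hooks are recorded first
    return out

# wrap every public method on the instance, as HookedModel does for the LightningModule hooks
for name in (h for h, _ in getmembers(Greeter, predicate=isfunction) if not h.startswith('_')):
    setattr(obj, name, partial(call, name, getattr(obj, name)))

obj.greet('world')
obj.farewell('world')
assert called == ['greet', 'farewell']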
def test_trainer_model_hook_system_fit(tmpdir):
model = HookedModel()
called = []
model = HookedModel(called)
train_batches = 2
val_batches = 2
trainer = Trainer(
@@ -462,53 +319,67 @@ def test_trainer_model_hook_system_fit(tmpdir):
progress_bar_refresh_rate=0,
weights_summary=None,
)
assert model.called == []
assert called == []
trainer.fit(model)
expected = [
'prepare_data',
'configure_callbacks',
'setup_fit',
'setup',
'configure_sharded_model',
'configure_optimizers',
'on_fit_start',
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_val_dataloader',
'val_dataloader',
'train', # eval() == train(False)
'on_validation_model_eval',
'zero_grad',
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
*(model.val_batch * val_batches),
*(HookedModel._val_batch() * val_batches),
'validation_epoch_end',
'on_validation_epoch_end',
'on_epoch_end',
'on_validation_end',
'train',
'on_validation_model_train',
# duplicate `train` because `_run_train` calls it again in case validation wasn't run
'train',
'on_train_dataloader',
'train_dataloader',
'on_train_start',
'on_epoch_start',
'on_train_epoch_start',
*(model.train_batch * train_batches),
*(HookedModel._train_batch() * train_batches),
'train', # eval() == train(False)
'on_validation_model_eval',
'zero_grad',
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
*(model.val_batch * val_batches),
*(HookedModel._val_batch() * val_batches),
'validation_epoch_end',
'on_validation_epoch_end',
'on_epoch_end',
'on_save_checkpoint',
'on_validation_end',
'train',
'on_validation_model_train',
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_train_end',
'on_fit_end',
'teardown_fit',
'teardown',
]
assert model.called == expected
assert called == expected
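The ordering in the expected list reflects a subtlety of the wrapper: the hook name is appended only after the wrapped call returns, so hooks triggered from inside another hook are recorded first. That is why 'train' precedes 'on_validation_model_eval' (eval() ends up calling train(False)) and why 'optimizer_step' comes after the hooks its closure runs. A tiny illustration with made-up outer/inner hooks, reusing the same `call` helper:

called = []

def call(hook, fn, *args, **kwargs):
    out = fn(*args, **kwargs)
    called.append(hook)
    return out

def inner():
    pass

def outer():
    # the outer hook triggers the (already wrapped) inner hook
    call('inner', inner)

call('outer', outer)
# the nested hook returns, and is recorded, before its caller
assert called == ['inner', 'outer']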
def test_trainer_model_hook_system_fit_no_val(tmpdir):
model = HookedModel()
called = []
model = HookedModel(called)
train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
@@ -518,96 +389,115 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
progress_bar_refresh_rate=0,
weights_summary=None,
)
assert model.called == []
assert called == []
trainer.fit(model)
expected = [
'prepare_data',
'configure_callbacks',
'setup_fit',
'setup',
'configure_sharded_model',
'configure_optimizers',
'on_fit_start',
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'train',
'on_train_dataloader',
'train_dataloader',
# even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
'on_val_dataloader',
'val_dataloader',
'on_train_start',
'on_epoch_start',
'on_train_epoch_start',
*(model.train_batch * train_batches),
*(HookedModel._train_batch() * train_batches),
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_save_checkpoint', # from train epoch end
'on_train_end',
'on_fit_end',
'teardown_fit',
'teardown',
]
assert model.called == expected
assert called == expected
def test_trainer_model_hook_system_validate(tmpdir):
model = HookedModel()
@pytest.mark.parametrize('batches', (0, 2))
def test_trainer_model_hook_system_validate(tmpdir, batches):
called = []
model = HookedModel(called)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=1,
limit_val_batches=batches,
progress_bar_refresh_rate=0,
weights_summary=None,
)
assert model.called == []
assert called == []
trainer.validate(model, verbose=False)
expected = [
'prepare_data',
'configure_callbacks',
'setup_validate',
hooks = [
'train', # eval() == train(False)
'on_validation_model_eval',
'zero_grad',
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'on_validation_batch_end',
*(HookedModel._val_batch() * batches),
'validation_epoch_end',
'on_validation_epoch_end',
'on_epoch_end',
'on_validation_end',
'on_validation_model_train',
'teardown_validate',
'train',
'on_validation_model_train'
]
assert model.called == expected
expected = [
'prepare_data',
'configure_callbacks',
'setup',
'configure_sharded_model',
'on_val_dataloader',
'val_dataloader',
*(hooks if batches else []),
'teardown',
]
assert called == expected
def test_trainer_model_hook_system_test(tmpdir):
model = HookedModel()
called = []
model = HookedModel(called)
batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_test_batches=1,
limit_test_batches=batches,
progress_bar_refresh_rate=0,
weights_summary=None,
)
assert model.called == []
assert called == []
trainer.test(model, verbose=False)
expected = [
'prepare_data',
'configure_callbacks',
'setup_test',
'setup',
'configure_sharded_model',
'on_test_dataloader',
'test_dataloader',
'train', # eval() == train(False)
'on_test_model_eval',
'zero_grad',
'on_test_start',
'on_epoch_start',
'on_test_epoch_start',
'on_test_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'on_test_batch_end',
*(HookedModel._test_batch() * batches),
'test_epoch_end',
'on_test_epoch_end',
'on_epoch_end',
'on_test_end',
'train',
'on_test_model_train',
'teardown_test',
'teardown',
]
assert model.called == expected
assert called == expected
def test_hooks_with_different_argument_names(tmpdir):
@@ -681,8 +571,7 @@ def test_trainer_datamodule_hook_system(tmpdir):
called.append(d)
return out
hooks = {h for h, _ in getmembers(LightningDataModule, predicate=isfunction)}
for h in hooks:
for h in get_members(LightningDataModule):
attr = getattr(self, h)
setattr(self, h, partial(call, h, attr))