diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py
index 77035796ca..9fadd947ac 100644
--- a/tests/helpers/datasets.py
+++ b/tests/helpers/datasets.py
@@ -105,7 +105,7 @@ class MNIST(Dataset):
             raise RuntimeError('Dataset not found.')
 
     def _download(self, data_folder: str) -> None:
-        os.makedirs(data_folder)
+        os.makedirs(data_folder, exist_ok=True)
         for url in self.RESOURCES:
             logging.info(f'Downloading {url}')
             fpath = os.path.join(data_folder, os.path.basename(url))
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 9904e43be6..6413ca8c93 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -20,7 +20,7 @@ import pytest
 import torch
 from torch.utils.data import DataLoader
 
-from pytorch_lightning import __version__, LightningDataModule, Trainer
+from pytorch_lightning import __version__, LightningDataModule, LightningModule, Trainer
 from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -231,17 +231,48 @@ def test_transfer_batch_hook_ddp(tmpdir):
     trainer.fit(model)
 
 
+def get_members(cls):
+    """Return the names of the public (non-underscore) function members of ``cls``."""
+    return {h for h, _ in getmembers(cls, predicate=isfunction) if not h.startswith('_')}
+
+
 class HookedModel(BoringModel):
+    """``BoringModel`` that records the name of each wrapped hook in ``called`` once the hook has returned."""
 
-    def __init__(self):
+    def __init__(self, called):
         super().__init__()
-        self.called = []
-        self.train_batch = [
+        pl_module_hooks = get_members(LightningModule)
+        # remove most `nn.Module` hooks
+        module_hooks = get_members(torch.nn.Module)
+        pl_module_hooks.difference_update(module_hooks - {'forward', 'zero_grad', 'train'})
+
+        def call(hook, fn, *args, **kwargs):
+            out = fn(*args, **kwargs)
+            called.append(hook)
+            return out
+
+        for h in pl_module_hooks:
+            attr = getattr(self, h)
+            setattr(self, h, partial(call, h, attr))
+
+    def validation_epoch_end(self, *args, **kwargs):
+        # `BoringModel` does not have a return for `validation_step_end` so this would fail
+        pass
+
+    def test_epoch_end(self, *args, **kwargs):
+        # `BoringModel` does not have a return for `test_step_end` so this would fail
+        pass
+
+    @staticmethod
+    def _train_batch():
+        return [
             'on_train_batch_start',
             'on_before_batch_transfer',
             'transfer_batch_to_device',
             'on_after_batch_transfer',
+            'forward',
             'training_step',
+            'training_step_end',
             'on_before_zero_grad',
             'optimizer_zero_grad',
             'backward',
@@ -249,209 +280,37 @@ class HookedModel(BoringModel):
             'optimizer_step',
             'on_train_batch_end',
         ]
-        self.val_batch = [
+
+    @staticmethod
+    def _val_batch():
+        return [
             'on_validation_batch_start',
             'on_before_batch_transfer',
             'transfer_batch_to_device',
             'on_after_batch_transfer',
+            'forward',
+            'validation_step',
+            'validation_step_end',
             'on_validation_batch_end',
         ]
 
-    def prepare_data(self):
-        self.called.append("prepare_data")
-        return super().prepare_data()
-
-    def configure_callbacks(self):
-        self.called.append("configure_callbacks")
-        return super().configure_callbacks()
-
-    def configure_optimizers(self):
-        self.called.append("configure_optimizers")
-        return super().configure_optimizers()
-
-    def training_step(self, *args, **kwargs):
-        self.called.append("training_step")
-        return super().training_step(*args, **kwargs)
-
-    def optimizer_zero_grad(self, *args, **kwargs):
-        self.called.append("optimizer_zero_grad")
-        super().optimizer_zero_grad(*args, **kwargs)
-
-    def training_epoch_end(self, *args, **kwargs):
-        self.called.append("training_epoch_end")
-        super().training_epoch_end(*args, **kwargs)
-
-    def backward(self, *args, **kwargs):
-        self.called.append("backward")
-        super().backward(*args, **kwargs)
-
-    def on_after_backward(self):
-        self.called.append("on_after_backward")
-        super().on_after_backward()
-
-    def optimizer_step(self, *args, **kwargs):
-        super().optimizer_step(*args, **kwargs)
-        self.called.append("optimizer_step")  # append after as closure calls other methods
-
-    def validation_epoch_end(self, *args, **kwargs):
-        self.called.append("validation_epoch_end")
-        super().validation_epoch_end(*args, **kwargs)
-
-    def on_before_zero_grad(self, *args, **kwargs):
-        self.called.append("on_before_zero_grad")
-        super().on_before_zero_grad(*args, **kwargs)
-
-    def on_epoch_start(self):
-        self.called.append("on_epoch_start")
-        super().on_epoch_start()
-
-    def on_epoch_end(self):
-        self.called.append("on_epoch_end")
-        super().on_epoch_end()
-
-    def on_fit_start(self):
-        self.called.append("on_fit_start")
-        super().on_fit_start()
-
-    def on_fit_end(self):
-        self.called.append("on_fit_end")
-        super().on_fit_end()
-
-    def on_hpc_load(self, *args, **kwargs):
-        self.called.append("on_hpc_load")
-        super().on_hpc_load(*args, **kwargs)
-
-    def on_hpc_save(self, *args, **kwargs):
-        self.called.append("on_hpc_save")
-        super().on_hpc_save(*args, **kwargs)
-
-    def on_load_checkpoint(self, *args, **kwargs):
-        self.called.append("on_load_checkpoint")
-        super().on_load_checkpoint(*args, **kwargs)
-
-    def on_save_checkpoint(self, *args, **kwargs):
-        self.called.append("on_save_checkpoint")
-        super().on_save_checkpoint(*args, **kwargs)
-
-    def on_pretrain_routine_start(self):
-        self.called.append("on_pretrain_routine_start")
-        super().on_pretrain_routine_start()
-
-    def on_pretrain_routine_end(self):
-        self.called.append("on_pretrain_routine_end")
-        super().on_pretrain_routine_end()
-
-    def on_train_start(self):
-        self.called.append("on_train_start")
-        super().on_train_start()
-
-    def on_train_end(self):
-        self.called.append("on_train_end")
-        super().on_train_end()
-
-    def on_before_batch_transfer(self, *args, **kwargs):
-        self.called.append("on_before_batch_transfer")
-        return super().on_before_batch_transfer(*args, **kwargs)
-
-    def transfer_batch_to_device(self, *args, **kwargs):
-        self.called.append("transfer_batch_to_device")
-        return super().transfer_batch_to_device(*args, **kwargs)
-
-    def on_after_batch_transfer(self, *args, **kwargs):
-        self.called.append("on_after_batch_transfer")
-        return super().on_after_batch_transfer(*args, **kwargs)
-
-    def on_train_batch_start(self, *args, **kwargs):
-        self.called.append("on_train_batch_start")
-        super().on_train_batch_start(*args, **kwargs)
-
-    def on_train_batch_end(self, *args, **kwargs):
-        self.called.append("on_train_batch_end")
-        super().on_train_batch_end(*args, **kwargs)
-
-    def on_train_epoch_start(self):
-        self.called.append("on_train_epoch_start")
-        super().on_train_epoch_start()
-
-    def on_train_epoch_end(self):
-        self.called.append("on_train_epoch_end")
-        super().on_train_epoch_end()
-
-    def on_validation_start(self):
-        self.called.append("on_validation_start")
-        super().on_validation_start()
-
-    def on_validation_end(self):
-        self.called.append("on_validation_end")
-        super().on_validation_end()
-
-    def on_validation_batch_start(self, *args, **kwargs):
-        self.called.append("on_validation_batch_start")
-        super().on_validation_batch_start(*args, **kwargs)
-
-    def on_validation_batch_end(self, *args, **kwargs):
-        self.called.append("on_validation_batch_end")
-        super().on_validation_batch_end(*args, **kwargs)
-
-    def on_validation_epoch_start(self):
-        self.called.append("on_validation_epoch_start")
-        super().on_validation_epoch_start()
-
-    def on_validation_epoch_end(self, *args, **kwargs):
-        self.called.append("on_validation_epoch_end")
-        super().on_validation_epoch_end(*args, **kwargs)
-
-    def on_test_start(self):
-        self.called.append("on_test_start")
-        super().on_test_start()
-
-    def on_test_batch_start(self, *args, **kwargs):
-        self.called.append("on_test_batch_start")
-        super().on_test_batch_start(*args, **kwargs)
-
-    def on_test_batch_end(self, *args, **kwargs):
-        self.called.append("on_test_batch_end")
-        super().on_test_batch_end(*args, **kwargs)
-
-    def on_test_epoch_start(self):
-        self.called.append("on_test_epoch_start")
-        super().on_test_epoch_start()
-
-    def on_test_epoch_end(self, *args, **kwargs):
-        self.called.append("on_test_epoch_end")
-        super().on_test_epoch_end(*args, **kwargs)
-
-    def on_validation_model_eval(self):
-        self.called.append("on_validation_model_eval")
-        super().on_validation_model_eval()
-
-    def on_validation_model_train(self):
-        self.called.append("on_validation_model_train")
-        super().on_validation_model_train()
-
-    def on_test_model_eval(self):
-        self.called.append("on_test_model_eval")
-        super().on_test_model_eval()
-
-    def on_test_model_train(self):
-        self.called.append("on_test_model_train")
-        super().on_test_model_train()
-
-    def on_test_end(self):
-        self.called.append("on_test_end")
-        super().on_test_end()
-
-    def setup(self, stage=None):
-        self.called.append(f"setup_{stage}")
-        super().setup(stage=stage)
-
-    def teardown(self, stage=None):
-        self.called.append(f"teardown_{stage}")
-        super().teardown(stage)
+    @staticmethod
+    def _test_batch():
+        return [
+            'on_test_batch_start',
+            'on_before_batch_transfer',
+            'transfer_batch_to_device',
+            'on_after_batch_transfer',
+            'forward',
+            'test_step',
+            'test_step_end',
+            'on_test_batch_end',
+        ]
 
 
 def test_trainer_model_hook_system_fit(tmpdir):
-    model = HookedModel()
+    called = []
+    model = HookedModel(called)
     train_batches = 2
     val_batches = 2
     trainer = Trainer(
@@ -462,53 +321,67 @@ def test_trainer_model_hook_system_fit(tmpdir):
         progress_bar_refresh_rate=0,
         weights_summary=None,
     )
-    assert model.called == []
+    assert called == []
     trainer.fit(model)
     expected = [
         'prepare_data',
         'configure_callbacks',
-        'setup_fit',
+        'setup',
+        'configure_sharded_model',
         'configure_optimizers',
         'on_fit_start',
         'on_pretrain_routine_start',
         'on_pretrain_routine_end',
+        'on_val_dataloader',
+        'val_dataloader',
+        'train',  # eval() == train(False)
         'on_validation_model_eval',
+        'zero_grad',
         'on_validation_start',
         'on_epoch_start',
         'on_validation_epoch_start',
-        *(model.val_batch * val_batches),
+        *(HookedModel._val_batch() * val_batches),
         'validation_epoch_end',
         'on_validation_epoch_end',
         'on_epoch_end',
         'on_validation_end',
+        'train',
         'on_validation_model_train',
+        # duplicate `train` because `_run_train` calls it again in case validation wasn't run
+        'train',
+        'on_train_dataloader',
+        'train_dataloader',
         'on_train_start',
         'on_epoch_start',
         'on_train_epoch_start',
-        *(model.train_batch * train_batches),
+        *(HookedModel._train_batch() * train_batches),
+        'train',  # eval() == train(False)
         'on_validation_model_eval',
+        'zero_grad',
         'on_validation_start',
         'on_epoch_start',
         'on_validation_epoch_start',
-        *(model.val_batch * val_batches),
+        *(HookedModel._val_batch() * val_batches),
         'validation_epoch_end',
         'on_validation_epoch_end',
         'on_epoch_end',
         'on_save_checkpoint',
         'on_validation_end',
+        'train',
         'on_validation_model_train',
         'training_epoch_end',
         'on_train_epoch_end',
         'on_epoch_end',
         'on_train_end',
         'on_fit_end',
-        'teardown_fit',
+        'teardown',
     ]
-    assert model.called == expected
+    assert called == expected
 
 
 def test_trainer_model_hook_system_fit_no_val(tmpdir):
-    model = HookedModel()
+    called = []
+    model = HookedModel(called)
     train_batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -518,96 +391,115 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
         progress_bar_refresh_rate=0,
         weights_summary=None,
     )
-    assert model.called == []
+    assert called == []
     trainer.fit(model)
     expected = [
         'prepare_data',
         'configure_callbacks',
-        'setup_fit',
+        'setup',
+        'configure_sharded_model',
         'configure_optimizers',
         'on_fit_start',
         'on_pretrain_routine_start',
         'on_pretrain_routine_end',
+        'train',
+        'on_train_dataloader',
+        'train_dataloader',
+        # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
+        'on_val_dataloader',
+        'val_dataloader',
         'on_train_start',
         'on_epoch_start',
         'on_train_epoch_start',
-        *(model.train_batch * train_batches),
+        *(HookedModel._train_batch() * train_batches),
         'training_epoch_end',
         'on_train_epoch_end',
         'on_epoch_end',
         'on_save_checkpoint',  # from train epoch end
         'on_train_end',
         'on_fit_end',
-        'teardown_fit',
+        'teardown',
     ]
-    assert model.called == expected
+    assert called == expected
 
 
-def test_trainer_model_hook_system_validate(tmpdir):
-    model = HookedModel()
+@pytest.mark.parametrize('batches', (0, 2))
+def test_trainer_model_hook_system_validate(tmpdir, batches):
+    called = []
+    model = HookedModel(called)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
-        limit_val_batches=1,
+        limit_val_batches=batches,
         progress_bar_refresh_rate=0,
         weights_summary=None,
     )
-    assert model.called == []
+    assert called == []
     trainer.validate(model, verbose=False)
-    expected = [
-        'prepare_data',
-        'configure_callbacks',
-        'setup_validate',
+    hooks = [
+        'train',  # eval() == train(False)
         'on_validation_model_eval',
+        'zero_grad',
         'on_validation_start',
         'on_epoch_start',
         'on_validation_epoch_start',
-        'on_validation_batch_start',
-        'on_before_batch_transfer',
-        'transfer_batch_to_device',
-        'on_after_batch_transfer',
-        'on_validation_batch_end',
+        *(HookedModel._val_batch() * batches),
         'validation_epoch_end',
         'on_validation_epoch_end',
         'on_epoch_end',
         'on_validation_end',
-        'on_validation_model_train',
-        'teardown_validate',
+        'train',
+        'on_validation_model_train',
     ]
-    assert model.called == expected
+    expected = [
+        'prepare_data',
+        'configure_callbacks',
+        'setup',
+        'configure_sharded_model',
+        'on_val_dataloader',
+        'val_dataloader',
+        *(hooks if batches else []),
+        'teardown',
+    ]
+    assert called == expected
 
 
 def test_trainer_model_hook_system_test(tmpdir):
-    model = HookedModel()
+    called = []
+    model = HookedModel(called)
+    batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
-        limit_test_batches=1,
+        limit_test_batches=batches,
         progress_bar_refresh_rate=0,
         weights_summary=None,
     )
-    assert model.called == []
+    assert called == []
     trainer.test(model, verbose=False)
     expected = [
         'prepare_data',
         'configure_callbacks',
-        'setup_test',
+        'setup',
+        'configure_sharded_model',
+        'on_test_dataloader',
+        'test_dataloader',
+        'train',  # eval() == train(False)
         'on_test_model_eval',
+        'zero_grad',
         'on_test_start',
         'on_epoch_start',
         'on_test_epoch_start',
-        'on_test_batch_start',
-        'on_before_batch_transfer',
-        'transfer_batch_to_device',
-        'on_after_batch_transfer',
-        'on_test_batch_end',
+        *(HookedModel._test_batch() * batches),
+        'test_epoch_end',
         'on_test_epoch_end',
         'on_epoch_end',
         'on_test_end',
+        'train',
         'on_test_model_train',
-        'teardown_test',
+        'teardown',
     ]
-    assert model.called == expected
+    assert called == expected
 
 
 def test_hooks_with_different_argument_names(tmpdir):
@@ -681,8 +573,7 @@ def test_trainer_datamodule_hook_system(tmpdir):
                 called.append(d)
                 return out
 
-            hooks = {h for h, _ in getmembers(LightningDataModule, predicate=isfunction)}
-            for h in hooks:
+            for h in get_members(LightningDataModule):
                 attr = getattr(self, h)
                 setattr(self, h, partial(call, h, attr))
 