From 28e18881a9ad2298169c78ad9ae109191e201c2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 1 Sep 2022 15:47:40 +0200 Subject: [PATCH] Mark stage argument in hooks as required (#14064) Co-authored-by: rohitgr7 --- .../advanced/training_tricks.rst | 2 +- docs/source-pytorch/data/datamodule.rst | 21 +++++------ examples/pl_domain_templates/imagenet.py | 4 +- examples/pl_loops/kfold.py | 2 +- .../cli/pl-app-template/core/callbacks.py | 8 ++-- src/pytorch_lightning/callbacks/callback.py | 4 +- .../callbacks/device_stats_monitor.py | 2 +- .../callbacks/early_stopping.py | 2 +- src/pytorch_lightning/callbacks/finetuning.py | 2 +- .../callbacks/model_checkpoint.py | 2 +- .../callbacks/progress/base.py | 2 +- .../callbacks/progress/rich_progress.py | 2 +- src/pytorch_lightning/callbacks/pruning.py | 2 +- .../callbacks/stochastic_weight_avg.py | 2 +- src/pytorch_lightning/cli.py | 2 +- src/pytorch_lightning/core/hooks.py | 4 +- src/pytorch_lightning/demos/boring_classes.py | 10 ++--- .../demos/mnist_datamodule.py | 2 +- src/pytorch_lightning/profilers/advanced.py | 2 +- src/pytorch_lightning/profilers/profiler.py | 6 +-- src/pytorch_lightning/profilers/pytorch.py | 2 +- .../trainer/configuration_validator.py | 7 ---- tests/tests_pytorch/accelerators/test_ipu.py | 3 +- .../callbacks/progress/test_base_progress.py | 2 +- .../checkpointing/test_model_checkpoint.py | 2 +- tests/tests_pytorch/core/test_datamodules.py | 4 +- tests/tests_pytorch/helpers/datamodules.py | 7 ++-- tests/tests_pytorch/models/test_hparams.py | 2 +- .../plugins/precision/hpu/test_hpu.py | 3 +- tests/tests_pytorch/strategies/test_ddp.py | 3 +- .../strategies/test_deepspeed_strategy.py | 6 +-- .../trainer/test_config_validator.py | 37 +------------------ 32 files changed, 56 insertions(+), 105 deletions(-) diff --git a/docs/source-pytorch/advanced/training_tricks.rst b/docs/source-pytorch/advanced/training_tricks.rst index 76d2f43176..71a778fa09 100644 --- a/docs/source-pytorch/advanced/training_tricks.rst +++ b/docs/source-pytorch/advanced/training_tricks.rst @@ -326,7 +326,7 @@ The :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class provid def prepare_data(self): MNIST(self.data_dir, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str): self.mnist = MNIST(self.data_dir) def train_loader(self): diff --git a/docs/source-pytorch/data/datamodule.rst b/docs/source-pytorch/data/datamodule.rst index 62a0f9d0d5..fbee2e80e4 100644 --- a/docs/source-pytorch/data/datamodule.rst +++ b/docs/source-pytorch/data/datamodule.rst @@ -84,7 +84,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa self.data_dir = data_dir self.batch_size = batch_size - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str): self.mnist_test = MNIST(self.data_dir, train=False) self.mnist_predict = MNIST(self.data_dir, train=False) mnist_full = MNIST(self.data_dir, train=True) @@ -102,7 +102,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa def predict_dataloader(self): return DataLoader(self.mnist_predict, batch_size=self.batch_size) - def teardown(self, stage: Optional[str] = None): + def teardown(self, stage: str): # Used to clean-up when the run is finished ... 
@@ -141,18 +141,18 @@ Here's a more realistic, complex DataModule that shows how much more reusable th MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str): # Assign train/val datasets for use in dataloaders - if stage == "fit" or stage is None: + if stage == "fit": mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) # Assign test dataset for use in dataloader(s) - if stage == "test" or stage is None: + if stage == "test": self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) - if stage == "predict" or stage is None: + if stage == "predict": self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform) def train_dataloader(self): @@ -226,15 +226,15 @@ There are also data operations you might want to perform on every GPU. Use :meth class MNISTDataModule(pl.LightningDataModule): - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str): # Assign Train/val split(s) for use in Dataloaders - if stage in (None, "fit"): + if stage == "fit": mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) # Assign Test split(s) for use in Dataloaders - if stage in (None, "test"): + if stage == "test": self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform) @@ -256,8 +256,7 @@ For eg., if you are working with NLP task where you need to tokenize the text an This method expects a ``stage`` argument. -It is used to separate setup logic for ``trainer.{fit,validate,test,predict}``. If ``setup`` is called with ``stage=None``, -we assume all stages have been set-up. +It is used to separate setup logic for ``trainer.{fit,validate,test,predict}``. .. note:: :ref:`setup` is called from every process across all the nodes. Setting state here is recommended. .. note:: :ref:`teardown` can be used to clean up the state. It is also called from every process across all the nodes. 
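A minimal, self-contained sketch of the contract documented above (illustrative only, not part of this patch): the Trainer always supplies one of "fit", "validate", "test", or "predict", so code that used to treat `stage=None` as "set up everything" now calls `setup`/`teardown` once per stage it needs.

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset, random_split


    class SketchDataModule(pl.LightningDataModule):
        def setup(self, stage: str) -> None:
            # `stage` is now always a concrete string; no `stage is None` branch.
            if stage == "fit":
                full = TensorDataset(torch.randn(100, 32))
                self.train_set, self.val_set = random_split(full, [80, 20])
            if stage == "test":
                self.test_set = TensorDataset(torch.randn(20, 32))

        def train_dataloader(self) -> DataLoader:
            return DataLoader(self.train_set, batch_size=8)

        def test_dataloader(self) -> DataLoader:
            return DataLoader(self.test_set, batch_size=8)

        def teardown(self, stage: str) -> None:
            # per-stage clean-up would go here
            pass


    # Outside the Trainer, the stage must now be passed explicitly:
    dm = SketchDataModule()
    dm.setup(stage="fit")
    dm.setup(stage="test")
    dm.teardown(stage="fit")
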
diff --git a/examples/pl_domain_templates/imagenet.py b/examples/pl_domain_templates/imagenet.py index 93284963db..efb9c40eea 100644 --- a/examples/pl_domain_templates/imagenet.py +++ b/examples/pl_domain_templates/imagenet.py @@ -125,7 +125,7 @@ class ImageNetLightningModel(LightningModule): scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30)) return [optimizer], [scheduler] - def setup(self, stage: Optional[str] = None): + def setup(self, stage: str): if isinstance(self.trainer.strategy, ParallelStrategy): # When using a single GPU per process and per `DistributedDataParallel`, we need to divide the batch size # ourselves based on the total number of GPUs we have @@ -133,7 +133,7 @@ class ImageNetLightningModel(LightningModule): self.batch_size = int(self.batch_size / num_processes) self.workers = int(self.workers / num_processes) - if stage in (None, "fit"): + if stage == "fit": normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_dir = os.path.join(self.data_path, "train") self.train_dataset = datasets.ImageFolder( diff --git a/examples/pl_loops/kfold.py b/examples/pl_loops/kfold.py index 028e0be698..529f0c6e1b 100644 --- a/examples/pl_loops/kfold.py +++ b/examples/pl_loops/kfold.py @@ -83,7 +83,7 @@ class MNISTKFoldDataModule(BaseKFoldDataModule): # download the data. MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: str) -> None: # load the data dataset = MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000]) diff --git a/src/lightning_app/cli/pl-app-template/core/callbacks.py b/src/lightning_app/cli/pl-app-template/core/callbacks.py index de1bb4003f..f324d10f1f 100644 --- a/src/lightning_app/cli/pl-app-template/core/callbacks.py +++ b/src/lightning_app/cli/pl-app-template/core/callbacks.py @@ -1,6 +1,6 @@ import inspect import logging -from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, TYPE_CHECKING, Union from core.state import ProgressBarState, TrainerState @@ -31,7 +31,7 @@ class PLAppProgressTracker(Callback): self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", - stage: Optional[str] = None, + stage: str, ) -> None: self.is_enabled = trainer.is_global_zero @@ -261,7 +261,7 @@ class PLAppSummary(Callback): self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", - stage: Optional[str] = None, + stage: str, ) -> None: self.work.model_hparams = self._sanitize_model_init_args(dict(**pl_module.hparams)) @@ -284,7 +284,7 @@ class PLAppArtifactsTracker(Callback): self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", - stage: Optional[str] = None, + stage: str, ) -> None: log_dir = self._get_logdir(trainer) self.work.log_dir = Path(log_dir) if log_dir is not None else None diff --git a/src/pytorch_lightning/callbacks/callback.py b/src/pytorch_lightning/callbacks/callback.py index 892bd0fdfb..cf57c5c2f7 100644 --- a/src/pytorch_lightning/callbacks/callback.py +++ b/src/pytorch_lightning/callbacks/callback.py @@ -72,10 +72,10 @@ class Callback: Called before accelerator is being setup. 
""" - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: """Called when fit, validate, test, predict, or tune begins.""" - def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: """Called when fit, validate, test, predict, or tune ends.""" def on_init_start(self, trainer: "pl.Trainer") -> None: diff --git a/src/pytorch_lightning/callbacks/device_stats_monitor.py b/src/pytorch_lightning/callbacks/device_stats_monitor.py index ed6750496a..c062fea8db 100644 --- a/src/pytorch_lightning/callbacks/device_stats_monitor.py +++ b/src/pytorch_lightning/callbacks/device_stats_monitor.py @@ -58,7 +58,7 @@ class DeviceStatsMonitor(Callback): self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", - stage: Optional[str] = None, + stage: str, ) -> None: if stage != "fit": return diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py index 87585bb812..79ba68e194 100644 --- a/src/pytorch_lightning/callbacks/early_stopping.py +++ b/src/pytorch_lightning/callbacks/early_stopping.py @@ -129,7 +129,7 @@ class EarlyStopping(Callback): def state_key(self) -> str: return self._generate_state_key(monitor=self.monitor, mode=self.mode) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: if self._check_on_train_epoch_end is None: # if the user runs validation multiple times per training epoch or multiple training epochs without # validation, then we run after validation instead of on train epoch end diff --git a/src/pytorch_lightning/callbacks/finetuning.py b/src/pytorch_lightning/callbacks/finetuning.py index d2afdd20bd..11cd81f7a2 100644 --- a/src/pytorch_lightning/callbacks/finetuning.py +++ b/src/pytorch_lightning/callbacks/finetuning.py @@ -244,7 +244,7 @@ class BaseFinetuning(Callback): if params: optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr}) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: self.freeze_before_training(pl_module) @staticmethod diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py index 1ad86a0917..3362d07902 100644 --- a/src/pytorch_lightning/callbacks/model_checkpoint.py +++ b/src/pytorch_lightning/callbacks/model_checkpoint.py @@ -254,7 +254,7 @@ class ModelCheckpoint(Checkpoint): save_on_train_epoch_end=self._save_on_train_epoch_end, ) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: self.__resolve_ckpt_dir(trainer) assert self.dirpath is not None if trainer.is_global_zero and stage == "fit": diff --git a/src/pytorch_lightning/callbacks/progress/base.py b/src/pytorch_lightning/callbacks/progress/base.py index 003cc7bc6f..4fd4597c99 100644 --- a/src/pytorch_lightning/callbacks/progress/base.py +++ b/src/pytorch_lightning/callbacks/progress/base.py @@ -217,7 +217,7 @@ 
class ProgressBarBase(Callback): """You should provide a way to print without breaking the progress bar.""" print(*args, **kwargs) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: self._trainer = trainer if not trainer.is_global_zero: self.disable() diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index 8ca2cb6671..e0d8fca2e7 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -476,7 +476,7 @@ class RichProgressBar(ProgressBarBase): if self._metric_component: self._metric_component.update(metrics) - def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None: + def teardown(self, trainer, pl_module, stage: str) -> None: self._stop_progress() def on_exception(self, trainer, pl_module, exception: BaseException) -> None: diff --git a/src/pytorch_lightning/callbacks/pruning.py b/src/pytorch_lightning/callbacks/pruning.py index 63516028b1..878fe674b8 100644 --- a/src/pytorch_lightning/callbacks/pruning.py +++ b/src/pytorch_lightning/callbacks/pruning.py @@ -361,7 +361,7 @@ class ModelPruning(Callback): f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})" ) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: parameters_to_prune = self.sanitize_parameters_to_prune( pl_module, self._parameters_to_prune, parameter_names=self._parameter_names ) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 6650bb3f0c..90e2c62a79 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -143,7 +143,7 @@ class StochasticWeightAveraging(Callback): def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool: return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: # copy the model before moving it to accelerator device. 
with pl_module._prevent_trainer_and_dataloaders_deepcopy(): self._average_model = deepcopy(pl_module) diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index d3990d79c5..700307b6ef 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -205,7 +205,7 @@ class SaveConfigCallback(Callback): self.overwrite = overwrite self.multifile = multifile - def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: log_dir = trainer.log_dir # this broadcasts the directory assert log_dir is not None config_path = os.path.join(log_dir, self.config_filename) diff --git a/src/pytorch_lightning/core/hooks.py b/src/pytorch_lightning/core/hooks.py index 4da53903ec..86b3d3f92e 100644 --- a/src/pytorch_lightning/core/hooks.py +++ b/src/pytorch_lightning/core/hooks.py @@ -380,7 +380,7 @@ class DataHooks: model.predict_dataloader() """ - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: str) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. @@ -406,7 +406,7 @@ class DataHooks: self.l1 = nn.Linear(28, data.num_classes) """ - def teardown(self, stage: Optional[str] = None) -> None: + def teardown(self, stage: str) -> None: """Called at the end of fit (train + validate), validate, test, or predict. Args: diff --git a/src/pytorch_lightning/demos/boring_classes.py b/src/pytorch_lightning/demos/boring_classes.py index f7be539046..7d79f89163 100644 --- a/src/pytorch_lightning/demos/boring_classes.py +++ b/src/pytorch_lightning/demos/boring_classes.py @@ -163,17 +163,17 @@ class BoringDataModule(LightningDataModule): self.checkpoint_state: Optional[str] = None self.random_full = RandomDataset(32, 64 * 4) - def setup(self, stage: Optional[str] = None) -> None: - if stage == "fit" or stage is None: + def setup(self, stage: str) -> None: + if stage == "fit": self.random_train = Subset(self.random_full, indices=range(64)) - if stage in ("fit", "validate") or stage is None: + if stage in ("fit", "validate"): self.random_val = Subset(self.random_full, indices=range(64, 64 * 2)) - if stage == "test" or stage is None: + if stage == "test": self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) - if stage == "predict" or stage is None: + if stage == "predict": self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) def train_dataloader(self) -> DataLoader: diff --git a/src/pytorch_lightning/demos/mnist_datamodule.py b/src/pytorch_lightning/demos/mnist_datamodule.py index 6466b78250..e1818e83f4 100644 --- a/src/pytorch_lightning/demos/mnist_datamodule.py +++ b/src/pytorch_lightning/demos/mnist_datamodule.py @@ -195,7 +195,7 @@ class MNISTDataModule(LightningDataModule): MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: str) -> None: """Split the train and valid dataset.""" extra = dict(transform=self.default_transforms) if self.default_transforms else {} dataset: Dataset = MNIST(self.data_dir, train=True, download=False, **extra) diff --git a/src/pytorch_lightning/profilers/advanced.py b/src/pytorch_lightning/profilers/advanced.py index 90fddc8074..73be0de3f8 100644 --- 
a/src/pytorch_lightning/profilers/advanced.py +++ b/src/pytorch_lightning/profilers/advanced.py @@ -78,7 +78,7 @@ class AdvancedProfiler(Profiler): recorded_stats[action_name] = s.getvalue() return self._stats_to_str(recorded_stats) - def teardown(self, stage: Optional[str] = None) -> None: + def teardown(self, stage: Optional[str]) -> None: super().teardown(stage=stage) self.profiled_actions = {} diff --git a/src/pytorch_lightning/profilers/profiler.py b/src/pytorch_lightning/profilers/profiler.py index 1b36159837..755007ba74 100644 --- a/src/pytorch_lightning/profilers/profiler.py +++ b/src/pytorch_lightning/profilers/profiler.py @@ -148,15 +148,13 @@ class Profiler(ABC): output.append(value) return os.linesep.join(output) - def setup( - self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None - ) -> None: + def setup(self, stage: str, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None: """Execute arbitrary pre-profiling set-up steps.""" self._stage = stage self._local_rank = local_rank self.dirpath = self.dirpath or log_dir - def teardown(self, stage: Optional[str] = None) -> None: + def teardown(self, stage: Optional[str]) -> None: """Execute arbitrary post-profiling tear-down steps. Closes the currently open file and stream. diff --git a/src/pytorch_lightning/profilers/pytorch.py b/src/pytorch_lightning/profilers/pytorch.py index 079aafe37e..9b843dccbf 100644 --- a/src/pytorch_lightning/profilers/pytorch.py +++ b/src/pytorch_lightning/profilers/pytorch.py @@ -505,7 +505,7 @@ class PyTorchProfiler(Profiler): self._register.__exit__(None, None, None) self._register = None - def teardown(self, stage: Optional[str] = None) -> None: + def teardown(self, stage: str) -> None: self._delete_profilers() for k in list(self._recording_map): diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index 8bf68e1bed..6ec2b15a11 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -56,7 +56,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: _check_on_pretrain_routine(model) # TODO: Delete CheckpointHooks off LightningDataModule in v1.8 _check_datamodule_checkpoint_hooks(trainer) - _check_setup_method(trainer) def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: @@ -309,9 +308,3 @@ def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None: "`LightningDataModule.on_load_checkpoint` was deprecated in" " v1.6 and will be removed in v1.8. Use `load_state_dict` instead." ) - - -def _check_setup_method(trainer: "pl.Trainer") -> None: - for obj in [trainer.lightning_module, trainer.datamodule] + trainer.callbacks: - if is_overridden("setup", obj) and not is_param_in_hook_signature(obj.setup, "stage"): - raise MisconfigurationException(f"`{obj.__class__.__name__}.setup` does not have a `stage` argument.") diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index 470cb4a028..d5958eae0e 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Optional from unittest import mock import pytest @@ -187,7 +186,7 @@ def test_optimization(tmpdir): @RunIf(ipu=True) def test_half_precision(tmpdir): class TestCallback(Callback): - def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: assert trainer.strategy.model.precision == 16 raise SystemExit diff --git a/tests/tests_pytorch/callbacks/progress/test_base_progress.py b/tests/tests_pytorch/callbacks/progress/test_base_progress.py index 75f276a6b9..588ec782d7 100644 --- a/tests/tests_pytorch/callbacks/progress/test_base_progress.py +++ b/tests/tests_pytorch/callbacks/progress/test_base_progress.py @@ -22,7 +22,7 @@ def test_main_progress_bar_with_val_check_interval_int(): limit_train_batches=train_batches, limit_val_batches=10, val_check_interval=3, check_val_every_n_epoch=None ) model = BoringModel() - trainer.progress_bar_callback.setup(trainer, model) + trainer.progress_bar_callback.setup(trainer, model, stage="fit") trainer.strategy.connect(model) trainer._data_connector.attach_data(model) trainer.reset_train_dataloader() diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 60ec4ec0f2..ebe0769d8d 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -1245,7 +1245,7 @@ def test_model_checkpoint_saveload_ckpt(tmpdir): # Case - 2 # Make sure that everything runs when dirpath is not initialized explicitly cb_restore = CustomModelCheckpoint() - cb_restore.setup(Trainer(), BoringModel()) + cb_restore.setup(Trainer(), BoringModel(), stage="fit") with pytest.warns(UserWarning, match="The dirpath has changed from*"): cb_restore.load_state_dict(written_ckpt) make_assertions(cb_restore, written_ckpt) diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 23419c102e..19fb5181b2 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -117,7 +117,7 @@ def test_hooks_no_recursion_error(): def test_helper_boringdatamodule(): dm = BoringDataModule() dm.prepare_data() - dm.setup() + dm.setup("fit") def test_helper_boringdatamodule_with_verbose_setup(): @@ -140,7 +140,7 @@ def test_dm_init_from_argparse_args(tmpdir): args = parser.parse_args(["--data_dir", str(tmpdir)]) dm = BoringDataModule.from_argparse_args(args) dm.prepare_data() - dm.setup() + dm.setup("fit") assert dm.data_dir == args.data_dir == str(tmpdir) diff --git a/tests/tests_pytorch/helpers/datamodules.py b/tests/tests_pytorch/helpers/datamodules.py index 6ad3151f3a..4278422593 100644 --- a/tests/tests_pytorch/helpers/datamodules.py +++ b/tests/tests_pytorch/helpers/datamodules.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional import pytest import torch @@ -42,10 +41,10 @@ class MNISTDataModule(LightningDataModule): self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): - if stage == "fit" or stage is None: + def setup(self, stage: str): + if stage == "fit": self.mnist_train = self.dataset_cls(self.data_dir, train=True) - if stage == "test" or stage is None: + if stage == "test": self.mnist_test = self.dataset_cls(self.data_dir, train=False) def train_dataloader(self): diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 84311d6f78..628eb28403 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -860,7 +860,7 @@ def test_no_datamodule_for_hparams(tmpdir): model = SaveHparamsModel({"arg1": 5, "arg2": "abc"}) org_model_hparams = copy.deepcopy(model.hparams_initial) data = DataModuleWithoutHparams() - data.setup() + data.setup("fit") mock_logger = _get_mock_logger(tmpdir) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger) diff --git a/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py b/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py index 5ca366f516..e8fc226435 100644 --- a/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py +++ b/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional import pytest import torch @@ -42,7 +41,7 @@ def test_precision_plugin(hmp_params): @RunIf(hpu=True) def test_mixed_precision(tmpdir, hmp_params: dict): class TestCallback(Callback): - def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: assert trainer.strategy.model.precision == "bf16" raise SystemExit diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index 9b196f3e2a..dbde198b6e 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Optional from unittest import mock from unittest.mock import patch @@ -94,7 +93,7 @@ def test_ddp_torch_dist_is_available_in_setup( """Test to ensure torch distributed is available within the setup hook using ddp.""" class TestModel(BoringModel): - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: str) -> None: assert torch.distributed.is_initialized() raise SystemExit() diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index e3c6f95f3f..857abaa8df 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -15,7 +15,7 @@ import contextlib import json import logging import os -from typing import Any, Dict, Optional +from typing import Any, Dict from unittest import mock import pytest @@ -263,7 +263,7 @@ def test_deepspeed_auto_batch_size_config_select(mock_deepspeed_distributed, moc return DataLoader(dataset_cls(32, 64)) class AssertCallback(Callback): - def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None: + def setup(self, trainer, pl_module, stage: str) -> None: assert isinstance(trainer.strategy, DeepSpeedStrategy) config = trainer.strategy.config @@ -1059,7 +1059,7 @@ def test_deepspeed_setup_train_dataloader(tmpdir): super().__init__() self._setup = False - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: str) -> None: self._setup = True def train_dataloader(self): diff --git a/tests/tests_pytorch/trainer/test_config_validator.py b/tests/tests_pytorch/trainer/test_config_validator.py index 7fba63ba7a..f6508c181e 100644 --- a/tests/tests_pytorch/trainer/test_config_validator.py +++ b/tests/tests_pytorch/trainer/test_config_validator.py @@ -16,8 +16,7 @@ import torch import pytorch_lightning as pl from pytorch_lightning import LightningDataModule, LightningModule, Trainer -from pytorch_lightning.callbacks.callback import Callback -from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import PossibleUserWarning @@ -162,40 +161,6 @@ def test_trainer_manual_optimization_config(tmpdir): trainer.fit(model) -def test_invalid_setup_method(): - """Test error message when `setup` method of `LightningModule` or `LightningDataModule` is not defined - correctly.""" - - class CustomModel(BoringModel): - def setup(self): - pass - - class CustomDataModule(BoringDataModule): - def setup(self): - pass - - class CustomBoringCallback(Callback): - def setup(self, pl_module, trainer): - pass - - fit_kwargs = [ - {"model": CustomModel(), "datamodule": BoringDataModule()}, - {"model": BoringModel(), "datamodule": CustomDataModule()}, - ] - - for kwargs in fit_kwargs: - trainer = Trainer(fast_dev_run=True) - - with pytest.raises(MisconfigurationException, match="does not have a `stage` argument"): - trainer.fit(**kwargs) - - trainer = Trainer(fast_dev_run=True, callbacks=[CustomBoringCallback()]) - model = BoringModel() - - with pytest.raises(MisconfigurationException, match="does not have a `stage` argument"): - trainer.fit(model) - - @pytest.mark.parametrize("trainer_kwargs", [{"accelerator": "ipu"}, {"accelerator": "gpu", "strategy": "dp"}]) 
@pytest.mark.parametrize("hook", ["transfer_batch_to_device", "on_after_batch_transfer"]) def test_raise_exception_with_batch_transfer_hooks(monkeypatch, hook, trainer_kwargs, tmpdir):
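
With `_check_setup_method` and `test_invalid_setup_method` removed above, a `setup` override that omits the now-required `stage` parameter is no longer reported through a dedicated `MisconfigurationException`; it simply fails when the hook is invoked. A plain-Python sketch of that behaviour (illustrative only; it assumes the hook is dispatched with `stage` passed explicitly, as the Trainer does):

    class WellFormed:
        def setup(self, stage: str) -> None:
            print(f"setup called with stage={stage}")


    class MissingStage:
        def setup(self) -> None:  # drops the required `stage` parameter
            print("setup called")


    WellFormed().setup(stage="fit")  # prints: setup called with stage=fit

    try:
        # Mimics how the hook is called with an explicit stage argument.
        MissingStage().setup(stage="fit")
    except TypeError as err:
        print(err)  # e.g. "setup() got an unexpected keyword argument 'stage'"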