Mark stage argument in hooks as required (#14064)

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Author: Adrian Wälchli, 2022-09-01 15:47:40 +02:00 (committed by GitHub)
parent 764b348249
commit 28e18881a9
32 changed files with 56 additions and 105 deletions
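
In short, the `setup` and `teardown` hooks now require a `stage` argument: the `= None` default is gone everywhere, and in most hooks the type narrows from `Optional[str]` to `str`. A minimal before/after sketch of what this means for user code (the class name is illustrative, not part of the diff):

```python
import pytorch_lightning as pl


class MyModel(pl.LightningModule):  # hypothetical user module
    # Before: def setup(self, stage: Optional[str] = None) -> None
    # After:  `stage` is required and is always passed by the Trainer,
    #         e.g. "fit", "validate", "test" or "predict".
    def setup(self, stage: str) -> None:
        ...
```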

View File

@@ -326,7 +326,7 @@ The :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class provid
     def prepare_data(self):
         MNIST(self.data_dir, download=True)

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):
         self.mnist = MNIST(self.data_dir)

     def train_loader(self):

View File

@@ -84,7 +84,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa
         self.data_dir = data_dir
         self.batch_size = batch_size

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):
         self.mnist_test = MNIST(self.data_dir, train=False)
         self.mnist_predict = MNIST(self.data_dir, train=False)
         mnist_full = MNIST(self.data_dir, train=True)
@@ -102,7 +102,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa
     def predict_dataloader(self):
         return DataLoader(self.mnist_predict, batch_size=self.batch_size)

-    def teardown(self, stage: Optional[str] = None):
+    def teardown(self, stage: str):
         # Used to clean-up when the run is finished
         ...
@@ -141,18 +141,18 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
         MNIST(self.data_dir, train=True, download=True)
         MNIST(self.data_dir, train=False, download=True)

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):
         # Assign train/val datasets for use in dataloaders
-        if stage == "fit" or stage is None:
+        if stage == "fit":
             mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
             self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

         # Assign test dataset for use in dataloader(s)
-        if stage == "test" or stage is None:
+        if stage == "test":
             self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

-        if stage == "predict" or stage is None:
+        if stage == "predict":
             self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

     def train_dataloader(self):
@@ -226,15 +226,15 @@ There are also data operations you might want to perform on every GPU. Use :meth
 class MNISTDataModule(pl.LightningDataModule):
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):
         # Assign Train/val split(s) for use in Dataloaders
-        if stage in (None, "fit"):
+        if stage == "fit":
             mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
             self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

         # Assign Test split(s) for use in Dataloaders
-        if stage in (None, "test"):
+        if stage == "test":
             self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform)
@@ -256,8 +256,7 @@ For eg., if you are working with NLP task where you need to tokenize the text an
 This method expects a ``stage`` argument.
-It is used to separate setup logic for ``trainer.{fit,validate,test,predict}``. If ``setup`` is called with ``stage=None``,
-we assume all stages have been set-up.
+It is used to separate setup logic for ``trainer.{fit,validate,test,predict}``.

 .. note:: :ref:`setup<data/datamodule:setup>` is called from every process across all the nodes. Setting state here is recommended.

 .. note:: :ref:`teardown<data/datamodule:teardown>` can be used to clean up the state. It is also called from every process across all the nodes.
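
If you call these hooks yourself rather than through the Trainer, the stage must now be passed explicitly, matching the test updates later in this commit. A minimal usage sketch (assuming the `MNISTDataModule` example above and its constructor arguments):

```python
dm = MNISTDataModule(data_dir="./data", batch_size=32)
dm.prepare_data()
dm.setup("fit")  # `stage` is now required; `dm.setup()` raises a TypeError
train_loader = dm.train_dataloader()
```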

View File

@@ -125,7 +125,7 @@ class ImageNetLightningModel(LightningModule):
         scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30))
         return [optimizer], [scheduler]

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):
         if isinstance(self.trainer.strategy, ParallelStrategy):
             # When using a single GPU per process and per `DistributedDataParallel`, we need to divide the batch size
             # ourselves based on the total number of GPUs we have
@@ -133,7 +133,7 @@ class ImageNetLightningModel(LightningModule):
             self.batch_size = int(self.batch_size / num_processes)
             self.workers = int(self.workers / num_processes)

-        if stage in (None, "fit"):
+        if stage == "fit":
             normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
             train_dir = os.path.join(self.data_path, "train")
             self.train_dataset = datasets.ImageFolder(

View File

@@ -83,7 +83,7 @@ class MNISTKFoldDataModule(BaseKFoldDataModule):
         # download the data.
         MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))

-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: str) -> None:
         # load the data
         dataset = MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))
         self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000])

View File

@@ -1,6 +1,6 @@
 import inspect
 import logging
-from typing import Any, Dict, Optional, TYPE_CHECKING, Union
+from typing import Any, Dict, TYPE_CHECKING, Union

 from core.state import ProgressBarState, TrainerState
@@ -31,7 +31,7 @@ class PLAppProgressTracker(Callback):
         self,
         trainer: "pl.Trainer",
         pl_module: "pl.LightningModule",
-        stage: Optional[str] = None,
+        stage: str,
     ) -> None:
         self.is_enabled = trainer.is_global_zero
@@ -261,7 +261,7 @@ class PLAppSummary(Callback):
         self,
         trainer: "pl.Trainer",
         pl_module: "pl.LightningModule",
-        stage: Optional[str] = None,
+        stage: str,
     ) -> None:
         self.work.model_hparams = self._sanitize_model_init_args(dict(**pl_module.hparams))
@@ -284,7 +284,7 @@ class PLAppArtifactsTracker(Callback):
         self,
         trainer: "pl.Trainer",
         pl_module: "pl.LightningModule",
-        stage: Optional[str] = None,
+        stage: str,
     ) -> None:
         log_dir = self._get_logdir(trainer)
         self.work.log_dir = Path(log_dir) if log_dir is not None else None

View File

@@ -72,10 +72,10 @@ class Callback:
         Called before accelerator is being setup.
         """

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins."""

-    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune ends."""

     def on_init_start(self, trainer: "pl.Trainer") -> None:
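
As a usage illustration (not part of this diff), a custom callback written against the updated base class declares `stage` without a default:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback


class PrintStageCallback(Callback):  # hypothetical example callback
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        # `stage` is always a concrete string such as "fit", "validate", "test" or "predict".
        print(f"setup for stage={stage}")

    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        print(f"teardown after stage={stage}")
```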

View File

@@ -58,7 +58,7 @@ class DeviceStatsMonitor(Callback):
         self,
         trainer: "pl.Trainer",
         pl_module: "pl.LightningModule",
-        stage: Optional[str] = None,
+        stage: str,
     ) -> None:
         if stage != "fit":
             return

View File

@@ -129,7 +129,7 @@ class EarlyStopping(Callback):
     def state_key(self) -> str:
         return self._generate_state_key(monitor=self.monitor, mode=self.mode)

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         if self._check_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end

View File

@@ -244,7 +244,7 @@ class BaseFinetuning(Callback):
         if params:
             optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         self.freeze_before_training(pl_module)

     @staticmethod

View File

@@ -254,7 +254,7 @@ class ModelCheckpoint(Checkpoint):
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         self.__resolve_ckpt_dir(trainer)
         assert self.dirpath is not None
         if trainer.is_global_zero and stage == "fit":

View File

@@ -217,7 +217,7 @@ class ProgressBarBase(Callback):
         """You should provide a way to print without breaking the progress bar."""
         print(*args, **kwargs)

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         self._trainer = trainer
         if not trainer.is_global_zero:
             self.disable()

View File

@@ -476,7 +476,7 @@ class RichProgressBar(ProgressBarBase):
         if self._metric_component:
             self._metric_component.update(metrics)

-    def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
+    def teardown(self, trainer, pl_module, stage: str) -> None:
         self._stop_progress()

     def on_exception(self, trainer, pl_module, exception: BaseException) -> None:

View File

@@ -361,7 +361,7 @@ class ModelPruning(Callback):
             f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
         )

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )

View File

@@ -143,7 +143,7 @@ class StochasticWeightAveraging(Callback):
     def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
         return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         # copy the model before moving it to accelerator device.
         with pl_module._prevent_trainer_and_dataloaders_deepcopy():
             self._average_model = deepcopy(pl_module)

View File

@@ -205,7 +205,7 @@ class SaveConfigCallback(Callback):
         self.overwrite = overwrite
         self.multifile = multifile

-    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         log_dir = trainer.log_dir  # this broadcasts the directory
         assert log_dir is not None
         config_path = os.path.join(log_dir, self.config_filename)

View File

@@ -380,7 +380,7 @@ class DataHooks:
             model.predict_dataloader()
         """

-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: str) -> None:
         """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when
         you need to build models dynamically or adjust something about them. This hook is called on every process
         when using DDP.
@@ -406,7 +406,7 @@ class DataHooks:
                     self.l1 = nn.Linear(28, data.num_classes)
         """

-    def teardown(self, stage: Optional[str] = None) -> None:
+    def teardown(self, stage: str) -> None:
         """Called at the end of fit (train + validate), validate, test, or predict.

         Args:

View File

@@ -163,17 +163,17 @@ class BoringDataModule(LightningDataModule):
         self.checkpoint_state: Optional[str] = None
         self.random_full = RandomDataset(32, 64 * 4)

-    def setup(self, stage: Optional[str] = None) -> None:
-        if stage == "fit" or stage is None:
+    def setup(self, stage: str) -> None:
+        if stage == "fit":
             self.random_train = Subset(self.random_full, indices=range(64))

-        if stage in ("fit", "validate") or stage is None:
+        if stage in ("fit", "validate"):
             self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))

-        if stage == "test" or stage is None:
+        if stage == "test":
             self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))

-        if stage == "predict" or stage is None:
+        if stage == "predict":
             self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))

     def train_dataloader(self) -> DataLoader:

View File

@@ -195,7 +195,7 @@ class MNISTDataModule(LightningDataModule):
         MNIST(self.data_dir, train=True, download=True)
         MNIST(self.data_dir, train=False, download=True)

-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: str) -> None:
         """Split the train and valid dataset."""
         extra = dict(transform=self.default_transforms) if self.default_transforms else {}
         dataset: Dataset = MNIST(self.data_dir, train=True, download=False, **extra)

View File

@@ -78,7 +78,7 @@ class AdvancedProfiler(Profiler):
             recorded_stats[action_name] = s.getvalue()
         return self._stats_to_str(recorded_stats)

-    def teardown(self, stage: Optional[str] = None) -> None:
+    def teardown(self, stage: Optional[str]) -> None:
         super().teardown(stage=stage)
         self.profiled_actions = {}

View File

@@ -148,15 +148,13 @@ class Profiler(ABC):
             output.append(value)
         return os.linesep.join(output)

-    def setup(
-        self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None
-    ) -> None:
+    def setup(self, stage: str, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
         """Execute arbitrary pre-profiling set-up steps."""
         self._stage = stage
         self._local_rank = local_rank
         self.dirpath = self.dirpath or log_dir

-    def teardown(self, stage: Optional[str] = None) -> None:
+    def teardown(self, stage: Optional[str]) -> None:
         """Execute arbitrary post-profiling tear-down steps.

         Closes the currently open file and stream.

View File

@@ -505,7 +505,7 @@ class PyTorchProfiler(Profiler):
             self._register.__exit__(None, None, None)
             self._register = None

-    def teardown(self, stage: Optional[str] = None) -> None:
+    def teardown(self, stage: str) -> None:
         self._delete_profilers()

         for k in list(self._recording_map):

View File

@@ -56,7 +56,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     _check_on_pretrain_routine(model)
     # TODO: Delete CheckpointHooks off LightningDataModule in v1.8
     _check_datamodule_checkpoint_hooks(trainer)
-    _check_setup_method(trainer)


 def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -309,9 +308,3 @@ def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
             "`LightningDataModule.on_load_checkpoint` was deprecated in"
             " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
         )
-
-
-def _check_setup_method(trainer: "pl.Trainer") -> None:
-    for obj in [trainer.lightning_module, trainer.datamodule] + trainer.callbacks:
-        if is_overridden("setup", obj) and not is_param_in_hook_signature(obj.setup, "stage"):
-            raise MisconfigurationException(f"`{obj.__class__.__name__}.setup` does not have a `stage` argument.")
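
My reading of why this check could be dropped (not stated in the diff): with `stage` required in the base signatures, an override that omits the parameter now fails with an ordinary `TypeError` as soon as the hook is called with a stage, so the dedicated misconfiguration check is redundant. A hedged sketch:

```python
from pytorch_lightning.demos.boring_classes import BoringModel


class ModelWithBadSetup(BoringModel):  # hypothetical example
    def setup(self):  # `stage` parameter missing
        pass


model = ModelWithBadSetup()
model.setup(stage="fit")  # raises TypeError: unexpected keyword argument 'stage'
```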

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Optional
 from unittest import mock

 import pytest
@@ -187,7 +186,7 @@ def test_optimization(tmpdir):
 @RunIf(ipu=True)
 def test_half_precision(tmpdir):
     class TestCallback(Callback):
-        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
             assert trainer.strategy.model.precision == 16
             raise SystemExit

View File

@@ -22,7 +22,7 @@ def test_main_progress_bar_with_val_check_interval_int():
         limit_train_batches=train_batches, limit_val_batches=10, val_check_interval=3, check_val_every_n_epoch=None
     )
     model = BoringModel()
-    trainer.progress_bar_callback.setup(trainer, model)
+    trainer.progress_bar_callback.setup(trainer, model, stage="fit")
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model)
     trainer.reset_train_dataloader()

View File

@@ -1245,7 +1245,7 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
     # Case - 2
     # Make sure that everything runs when dirpath is not initialized explicitly
     cb_restore = CustomModelCheckpoint()
-    cb_restore.setup(Trainer(), BoringModel())
+    cb_restore.setup(Trainer(), BoringModel(), stage="fit")
     with pytest.warns(UserWarning, match="The dirpath has changed from*"):
         cb_restore.load_state_dict(written_ckpt)
         make_assertions(cb_restore, written_ckpt)

View File

@@ -117,7 +117,7 @@ def test_hooks_no_recursion_error():
 def test_helper_boringdatamodule():
     dm = BoringDataModule()
     dm.prepare_data()
-    dm.setup()
+    dm.setup("fit")


 def test_helper_boringdatamodule_with_verbose_setup():
@@ -140,7 +140,7 @@ def test_dm_init_from_argparse_args(tmpdir):
     args = parser.parse_args(["--data_dir", str(tmpdir)])
     dm = BoringDataModule.from_argparse_args(args)
     dm.prepare_data()
-    dm.setup()
+    dm.setup("fit")
     assert dm.data_dir == args.data_dir == str(tmpdir)

View File

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional

 import pytest
 import torch
@@ -42,10 +41,10 @@ class MNISTDataModule(LightningDataModule):
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)

-    def setup(self, stage: Optional[str] = None):
-        if stage == "fit" or stage is None:
+    def setup(self, stage: str):
+        if stage == "fit":
             self.mnist_train = self.dataset_cls(self.data_dir, train=True)
-        if stage == "test" or stage is None:
+        if stage == "test":
             self.mnist_test = self.dataset_cls(self.data_dir, train=False)

     def train_dataloader(self):

View File

@@ -860,7 +860,7 @@ def test_no_datamodule_for_hparams(tmpdir):
     model = SaveHparamsModel({"arg1": 5, "arg2": "abc"})
     org_model_hparams = copy.deepcopy(model.hparams_initial)
     data = DataModuleWithoutHparams()
-    data.setup()
+    data.setup("fit")

     mock_logger = _get_mock_logger(tmpdir)
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger)

View File

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional

 import pytest
 import torch
@@ -42,7 +41,7 @@ def test_precision_plugin(hmp_params):
 @RunIf(hpu=True)
 def test_mixed_precision(tmpdir, hmp_params: dict):
     class TestCallback(Callback):
-        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
             assert trainer.strategy.model.precision == "bf16"
             raise SystemExit

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Optional
 from unittest import mock
 from unittest.mock import patch
@@ -94,7 +93,7 @@ def test_ddp_torch_dist_is_available_in_setup(
     """Test to ensure torch distributed is available within the setup hook using ddp."""

     class TestModel(BoringModel):
-        def setup(self, stage: Optional[str] = None) -> None:
+        def setup(self, stage: str) -> None:
             assert torch.distributed.is_initialized()
             raise SystemExit()

View File

@@ -15,7 +15,7 @@ import contextlib
 import json
 import logging
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 from unittest import mock

 import pytest
@@ -263,7 +263,7 @@ def test_deepspeed_auto_batch_size_config_select(mock_deepspeed_distributed, moc
             return DataLoader(dataset_cls(32, 64))

     class AssertCallback(Callback):
-        def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None:
+        def setup(self, trainer, pl_module, stage: str) -> None:
             assert isinstance(trainer.strategy, DeepSpeedStrategy)
             config = trainer.strategy.config
@@ -1059,7 +1059,7 @@ def test_deepspeed_setup_train_dataloader(tmpdir):
             super().__init__()
             self._setup = False

-        def setup(self, stage: Optional[str] = None) -> None:
+        def setup(self, stage: str) -> None:
             self._setup = True

         def train_dataloader(self):

View File

@@ -16,8 +16,7 @@ import torch

 import pytorch_lightning as pl
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks.callback import Callback
-from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
@@ -162,40 +161,6 @@ def test_trainer_manual_optimization_config(tmpdir):
         trainer.fit(model)


-def test_invalid_setup_method():
-    """Test error message when `setup` method of `LightningModule` or `LightningDataModule` is not defined
-    correctly."""
-
-    class CustomModel(BoringModel):
-        def setup(self):
-            pass
-
-    class CustomDataModule(BoringDataModule):
-        def setup(self):
-            pass
-
-    class CustomBoringCallback(Callback):
-        def setup(self, pl_module, trainer):
-            pass
-
-    fit_kwargs = [
-        {"model": CustomModel(), "datamodule": BoringDataModule()},
-        {"model": BoringModel(), "datamodule": CustomDataModule()},
-    ]
-
-    for kwargs in fit_kwargs:
-        trainer = Trainer(fast_dev_run=True)
-        with pytest.raises(MisconfigurationException, match="does not have a `stage` argument"):
-            trainer.fit(**kwargs)
-
-    trainer = Trainer(fast_dev_run=True, callbacks=[CustomBoringCallback()])
-    model = BoringModel()
-    with pytest.raises(MisconfigurationException, match="does not have a `stage` argument"):
-        trainer.fit(model)
-
-
 @pytest.mark.parametrize("trainer_kwargs", [{"accelerator": "ipu"}, {"accelerator": "gpu", "strategy": "dp"}])
 @pytest.mark.parametrize("hook", ["transfer_batch_to_device", "on_after_batch_transfer"])
 def test_raise_exception_with_batch_transfer_hooks(monkeypatch, hook, trainer_kwargs, tmpdir):