Mark stage argument in hooks as required (#14064)
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
parent 764b348249
commit 28e18881a9
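The diff below replaces the ``Optional[str] = None`` default on the ``stage`` parameter of the ``setup``/``teardown`` hooks with a plain ``str``, so overriding code must accept the stage explicitly. A minimal sketch of what an override looks like after this change (the module name is illustrative, not part of the diff):

    from pytorch_lightning import LightningDataModule


    class MyDataModule(LightningDataModule):  # illustrative name, not from this diff
        # previously: def setup(self, stage: Optional[str] = None)
        def setup(self, stage: str) -> None:
            # `stage` is always passed now, e.g. "fit", "validate", "test" or "predict"
            if stage == "fit":
                ...  # build the training/validation datasets here

        def teardown(self, stage: str) -> None:
            ...  # clean up whatever `setup` created for this stage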
@@ -326,7 +326,7 @@ The :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class provides
     def prepare_data(self):
         MNIST(self.data_dir, download=True)

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):
         self.mnist = MNIST(self.data_dir)

     def train_loader(self):

@@ -84,7 +84,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusable
         self.data_dir = data_dir
         self.batch_size = batch_size

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):
         self.mnist_test = MNIST(self.data_dir, train=False)
         self.mnist_predict = MNIST(self.data_dir, train=False)
         mnist_full = MNIST(self.data_dir, train=True)

@@ -102,7 +102,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusable
     def predict_dataloader(self):
         return DataLoader(self.mnist_predict, batch_size=self.batch_size)

-    def teardown(self, stage: Optional[str] = None):
+    def teardown(self, stage: str):
         # Used to clean-up when the run is finished
         ...

@@ -141,18 +141,18 @@ Here's a more realistic, complex DataModule that shows how much more reusable the
         MNIST(self.data_dir, train=True, download=True)
         MNIST(self.data_dir, train=False, download=True)

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):

         # Assign train/val datasets for use in dataloaders
-        if stage == "fit" or stage is None:
+        if stage == "fit":
             mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
             self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

         # Assign test dataset for use in dataloader(s)
-        if stage == "test" or stage is None:
+        if stage == "test":
             self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

-        if stage == "predict" or stage is None:
+        if stage == "predict":
             self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

     def train_dataloader(self):

@@ -226,15 +226,15 @@ There are also data operations you might want to perform on every GPU. Use :meth


     class MNISTDataModule(pl.LightningDataModule):
-        def setup(self, stage: Optional[str] = None):
+        def setup(self, stage: str):

             # Assign Train/val split(s) for use in Dataloaders
-            if stage in (None, "fit"):
+            if stage == "fit":
                 mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
                 self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

             # Assign Test split(s) for use in Dataloaders
-            if stage in (None, "test"):
+            if stage == "test":
                 self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform)


@@ -256,8 +256,7 @@ For eg., if you are working with NLP task where you need to tokenize the text and


 This method expects a ``stage`` argument.
-It is used to separate setup logic for ``trainer.{fit,validate,test,predict}``. If ``setup`` is called with ``stage=None``,
-we assume all stages have been set-up.
+It is used to separate setup logic for ``trainer.{fit,validate,test,predict}``.

 .. note:: :ref:`setup<data/datamodule:setup>` is called from every process across all the nodes. Setting state here is recommended.
 .. note:: :ref:`teardown<data/datamodule:teardown>` can be used to clean up the state. It is also called from every process across all the nodes.

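As a usage note on the docs change above: because ``stage`` can no longer be ``None``, per-stage branches no longer need a fallback that builds everything. A small self-contained sketch (class name and tensor sizes are illustrative, not from this diff):

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset, random_split


    class RandomDataModule(pl.LightningDataModule):  # illustrative, not from this diff
        def setup(self, stage: str) -> None:
            # build only what the requested stage needs
            if stage == "fit":
                full = TensorDataset(torch.randn(100, 16))
                self.train_set, self.val_set = random_split(full, [90, 10])
            if stage == "test":
                self.test_set = TensorDataset(torch.randn(20, 16))

        def train_dataloader(self) -> DataLoader:
            return DataLoader(self.train_set, batch_size=8)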
@@ -125,7 +125,7 @@ class ImageNetLightningModel(LightningModule):
         scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30))
         return [optimizer], [scheduler]

-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: str):
         if isinstance(self.trainer.strategy, ParallelStrategy):
             # When using a single GPU per process and per `DistributedDataParallel`, we need to divide the batch size
             # ourselves based on the total number of GPUs we have

@@ -133,7 +133,7 @@ class ImageNetLightningModel(LightningModule):
             self.batch_size = int(self.batch_size / num_processes)
             self.workers = int(self.workers / num_processes)

-        if stage in (None, "fit"):
+        if stage == "fit":
             normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
             train_dir = os.path.join(self.data_path, "train")
             self.train_dataset = datasets.ImageFolder(

@@ -83,7 +83,7 @@ class MNISTKFoldDataModule(BaseKFoldDataModule):
         # download the data.
         MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))

-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: str) -> None:
         # load the data
         dataset = MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))
         self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000])

@@ -1,6 +1,6 @@
 import inspect
 import logging
-from typing import Any, Dict, Optional, TYPE_CHECKING, Union
+from typing import Any, Dict, TYPE_CHECKING, Union

 from core.state import ProgressBarState, TrainerState


@@ -31,7 +31,7 @@ class PLAppProgressTracker(Callback):
         self,
         trainer: "pl.Trainer",
         pl_module: "pl.LightningModule",
-        stage: Optional[str] = None,
+        stage: str,
     ) -> None:
         self.is_enabled = trainer.is_global_zero

@@ -261,7 +261,7 @@ class PLAppSummary(Callback):
         self,
         trainer: "pl.Trainer",
         pl_module: "pl.LightningModule",
-        stage: Optional[str] = None,
+        stage: str,
     ) -> None:
         self.work.model_hparams = self._sanitize_model_init_args(dict(**pl_module.hparams))

@@ -284,7 +284,7 @@ class PLAppArtifactsTracker(Callback):
         self,
         trainer: "pl.Trainer",
         pl_module: "pl.LightningModule",
-        stage: Optional[str] = None,
+        stage: str,
     ) -> None:
         log_dir = self._get_logdir(trainer)
         self.work.log_dir = Path(log_dir) if log_dir is not None else None

@@ -72,10 +72,10 @@ class Callback:
         Called before accelerator is being setup.
         """

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins."""

-    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune ends."""

     def on_init_start(self, trainer: "pl.Trainer") -> None:

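Third-party callbacks follow the same signature change. A hedged sketch of a custom ``Callback`` that relies on the now-guaranteed ``stage`` string (the callback name and print statements are illustrative, not from this diff):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import Callback


    class StageLogger(Callback):  # illustrative, not from this diff
        def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
            # `stage` is always a string here, so no `None` check is needed
            print(f"setup: stage={stage}")

        def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
            print(f"teardown: stage={stage}")

It would be passed to the Trainer as usual, e.g. ``Trainer(callbacks=[StageLogger()])``.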
@@ -58,7 +58,7 @@ class DeviceStatsMonitor(Callback):
         self,
         trainer: "pl.Trainer",
         pl_module: "pl.LightningModule",
-        stage: Optional[str] = None,
+        stage: str,
     ) -> None:
         if stage != "fit":
             return

@@ -129,7 +129,7 @@ class EarlyStopping(Callback):
     def state_key(self) -> str:
         return self._generate_state_key(monitor=self.monitor, mode=self.mode)

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         if self._check_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end

@@ -244,7 +244,7 @@ class BaseFinetuning(Callback):
         if params:
             optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         self.freeze_before_training(pl_module)

     @staticmethod

@@ -254,7 +254,7 @@ class ModelCheckpoint(Checkpoint):
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         self.__resolve_ckpt_dir(trainer)
         assert self.dirpath is not None
         if trainer.is_global_zero and stage == "fit":

@@ -217,7 +217,7 @@ class ProgressBarBase(Callback):
         """You should provide a way to print without breaking the progress bar."""
         print(*args, **kwargs)

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         self._trainer = trainer
         if not trainer.is_global_zero:
             self.disable()

@@ -476,7 +476,7 @@ class RichProgressBar(ProgressBarBase):
         if self._metric_component:
             self._metric_component.update(metrics)

-    def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
+    def teardown(self, trainer, pl_module, stage: str) -> None:
         self._stop_progress()

     def on_exception(self, trainer, pl_module, exception: BaseException) -> None:

@@ -361,7 +361,7 @@ class ModelPruning(Callback):
             f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
         )

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )

@@ -143,7 +143,7 @@ class StochasticWeightAveraging(Callback):
     def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
         return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         # copy the model before moving it to accelerator device.
         with pl_module._prevent_trainer_and_dataloaders_deepcopy():
             self._average_model = deepcopy(pl_module)

@@ -205,7 +205,7 @@ class SaveConfigCallback(Callback):
         self.overwrite = overwrite
         self.multifile = multifile

-    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         log_dir = trainer.log_dir  # this broadcasts the directory
         assert log_dir is not None
         config_path = os.path.join(log_dir, self.config_filename)

@@ -380,7 +380,7 @@ class DataHooks:
             model.predict_dataloader()
         """

-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: str) -> None:
         """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when
         you need to build models dynamically or adjust something about them. This hook is called on every process
         when using DDP.

@@ -406,7 +406,7 @@ class DataHooks:
             self.l1 = nn.Linear(28, data.num_classes)
         """

-    def teardown(self, stage: Optional[str] = None) -> None:
+    def teardown(self, stage: str) -> None:
         """Called at the end of fit (train + validate), validate, test, or predict.

         Args:

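The ``LightningModule``-side ``setup``/``teardown`` hooks documented above follow the same pattern. A minimal sketch of building a layer dynamically in ``setup``, mirroring the docstring example (the model name and layer sizes are placeholder assumptions):

    import torch.nn as nn
    import pytorch_lightning as pl


    class LitClassifier(pl.LightningModule):  # illustrative, not from this diff
        def setup(self, stage: str) -> None:
            # runs on every process (e.g. under DDP), once per stage
            if stage == "fit":
                # 28 inputs / 10 classes are placeholder numbers for illustration
                self.l1 = nn.Linear(28, 10)

        def teardown(self, stage: str) -> None:
            pass  # release any per-stage state created in `setup`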
@@ -163,17 +163,17 @@ class BoringDataModule(LightningDataModule):
         self.checkpoint_state: Optional[str] = None
         self.random_full = RandomDataset(32, 64 * 4)

-    def setup(self, stage: Optional[str] = None) -> None:
-        if stage == "fit" or stage is None:
+    def setup(self, stage: str) -> None:
+        if stage == "fit":
             self.random_train = Subset(self.random_full, indices=range(64))

-        if stage in ("fit", "validate") or stage is None:
+        if stage in ("fit", "validate"):
             self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))

-        if stage == "test" or stage is None:
+        if stage == "test":
             self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))

-        if stage == "predict" or stage is None:
+        if stage == "predict":
             self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))

     def train_dataloader(self) -> DataLoader:

@@ -195,7 +195,7 @@ class MNISTDataModule(LightningDataModule):
         MNIST(self.data_dir, train=True, download=True)
         MNIST(self.data_dir, train=False, download=True)

-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: str) -> None:
         """Split the train and valid dataset."""
         extra = dict(transform=self.default_transforms) if self.default_transforms else {}
         dataset: Dataset = MNIST(self.data_dir, train=True, download=False, **extra)

@@ -78,7 +78,7 @@ class AdvancedProfiler(Profiler):
             recorded_stats[action_name] = s.getvalue()
         return self._stats_to_str(recorded_stats)

-    def teardown(self, stage: Optional[str] = None) -> None:
+    def teardown(self, stage: Optional[str]) -> None:
         super().teardown(stage=stage)
         self.profiled_actions = {}

@@ -148,15 +148,13 @@ class Profiler(ABC):
             output.append(value)
         return os.linesep.join(output)

-    def setup(
-        self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None
-    ) -> None:
+    def setup(self, stage: str, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
         """Execute arbitrary pre-profiling set-up steps."""
         self._stage = stage
         self._local_rank = local_rank
         self.dirpath = self.dirpath or log_dir

-    def teardown(self, stage: Optional[str] = None) -> None:
+    def teardown(self, stage: Optional[str]) -> None:
         """Execute arbitrary post-profiling tear-down steps.

         Closes the currently open file and stream.

@@ -505,7 +505,7 @@ class PyTorchProfiler(Profiler):
             self._register.__exit__(None, None, None)
             self._register = None

-    def teardown(self, stage: Optional[str] = None) -> None:
+    def teardown(self, stage: str) -> None:
         self._delete_profilers()

         for k in list(self._recording_map):

@@ -56,7 +56,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     _check_on_pretrain_routine(model)
     # TODO: Delete CheckpointHooks off LightningDataModule in v1.8
     _check_datamodule_checkpoint_hooks(trainer)
-    _check_setup_method(trainer)


 def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:

@@ -309,9 +308,3 @@ def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
             "`LightningDataModule.on_load_checkpoint` was deprecated in"
             " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
         )
-
-
-def _check_setup_method(trainer: "pl.Trainer") -> None:
-    for obj in [trainer.lightning_module, trainer.datamodule] + trainer.callbacks:
-        if is_overridden("setup", obj) and not is_param_in_hook_signature(obj.setup, "stage"):
-            raise MisconfigurationException(f"`{obj.__class__.__name__}.setup` does not have a `stage` argument.")

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Optional
 from unittest import mock

 import pytest

@@ -187,7 +186,7 @@ def test_optimization(tmpdir):
 @RunIf(ipu=True)
 def test_half_precision(tmpdir):
     class TestCallback(Callback):
-        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
             assert trainer.strategy.model.precision == 16
             raise SystemExit

@@ -22,7 +22,7 @@ def test_main_progress_bar_with_val_check_interval_int():
         limit_train_batches=train_batches, limit_val_batches=10, val_check_interval=3, check_val_every_n_epoch=None
     )
     model = BoringModel()
-    trainer.progress_bar_callback.setup(trainer, model)
+    trainer.progress_bar_callback.setup(trainer, model, stage="fit")
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model)
     trainer.reset_train_dataloader()

@@ -1245,7 +1245,7 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
     # Case - 2
     # Make sure that everything runs when dirpath is not initialized explicitly
     cb_restore = CustomModelCheckpoint()
-    cb_restore.setup(Trainer(), BoringModel())
+    cb_restore.setup(Trainer(), BoringModel(), stage="fit")
     with pytest.warns(UserWarning, match="The dirpath has changed from*"):
         cb_restore.load_state_dict(written_ckpt)
     make_assertions(cb_restore, written_ckpt)

@@ -117,7 +117,7 @@ def test_hooks_no_recursion_error():
 def test_helper_boringdatamodule():
     dm = BoringDataModule()
     dm.prepare_data()
-    dm.setup()
+    dm.setup("fit")


 def test_helper_boringdatamodule_with_verbose_setup():

@@ -140,7 +140,7 @@ def test_dm_init_from_argparse_args(tmpdir):
     args = parser.parse_args(["--data_dir", str(tmpdir)])
     dm = BoringDataModule.from_argparse_args(args)
     dm.prepare_data()
-    dm.setup()
+    dm.setup("fit")
     assert dm.data_dir == args.data_dir == str(tmpdir)


@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional

 import pytest
 import torch

@@ -42,10 +41,10 @@ class MNISTDataModule(LightningDataModule):
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)

-    def setup(self, stage: Optional[str] = None):
-        if stage == "fit" or stage is None:
+    def setup(self, stage: str):
+        if stage == "fit":
             self.mnist_train = self.dataset_cls(self.data_dir, train=True)
-        if stage == "test" or stage is None:
+        if stage == "test":
             self.mnist_test = self.dataset_cls(self.data_dir, train=False)

     def train_dataloader(self):

@@ -860,7 +860,7 @@ def test_no_datamodule_for_hparams(tmpdir):
     model = SaveHparamsModel({"arg1": 5, "arg2": "abc"})
     org_model_hparams = copy.deepcopy(model.hparams_initial)
     data = DataModuleWithoutHparams()
-    data.setup()
+    data.setup("fit")

     mock_logger = _get_mock_logger(tmpdir)
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger)

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional

 import pytest
 import torch

@@ -42,7 +41,7 @@ def test_precision_plugin(hmp_params):
 @RunIf(hpu=True)
 def test_mixed_precision(tmpdir, hmp_params: dict):
     class TestCallback(Callback):
-        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
             assert trainer.strategy.model.precision == "bf16"
             raise SystemExit

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Optional
 from unittest import mock
 from unittest.mock import patch


@@ -94,7 +93,7 @@ def test_ddp_torch_dist_is_available_in_setup(
     """Test to ensure torch distributed is available within the setup hook using ddp."""

     class TestModel(BoringModel):
-        def setup(self, stage: Optional[str] = None) -> None:
+        def setup(self, stage: str) -> None:
             assert torch.distributed.is_initialized()
             raise SystemExit()

@@ -15,7 +15,7 @@ import contextlib
 import json
 import logging
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 from unittest import mock

 import pytest

@@ -263,7 +263,7 @@ def test_deepspeed_auto_batch_size_config_select(mock_deepspeed_distributed, moc
         return DataLoader(dataset_cls(32, 64))

     class AssertCallback(Callback):
-        def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None:
+        def setup(self, trainer, pl_module, stage: str) -> None:
             assert isinstance(trainer.strategy, DeepSpeedStrategy)
             config = trainer.strategy.config

@@ -1059,7 +1059,7 @@ def test_deepspeed_setup_train_dataloader(tmpdir):
             super().__init__()
             self._setup = False

-        def setup(self, stage: Optional[str] = None) -> None:
+        def setup(self, stage: str) -> None:
             self._setup = True

         def train_dataloader(self):

@@ -16,8 +16,7 @@ import torch

 import pytorch_lightning as pl
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.callbacks.callback import Callback
-from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
+from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import PossibleUserWarning

@@ -162,40 +161,6 @@ def test_trainer_manual_optimization_config(tmpdir):
     trainer.fit(model)


-def test_invalid_setup_method():
-    """Test error message when `setup` method of `LightningModule` or `LightningDataModule` is not defined
-    correctly."""
-
-    class CustomModel(BoringModel):
-        def setup(self):
-            pass
-
-    class CustomDataModule(BoringDataModule):
-        def setup(self):
-            pass
-
-    class CustomBoringCallback(Callback):
-        def setup(self, pl_module, trainer):
-            pass
-
-    fit_kwargs = [
-        {"model": CustomModel(), "datamodule": BoringDataModule()},
-        {"model": BoringModel(), "datamodule": CustomDataModule()},
-    ]
-
-    for kwargs in fit_kwargs:
-        trainer = Trainer(fast_dev_run=True)
-
-        with pytest.raises(MisconfigurationException, match="does not have a `stage` argument"):
-            trainer.fit(**kwargs)
-
-    trainer = Trainer(fast_dev_run=True, callbacks=[CustomBoringCallback()])
-    model = BoringModel()
-
-    with pytest.raises(MisconfigurationException, match="does not have a `stage` argument"):
-        trainer.fit(model)
-
-
 @pytest.mark.parametrize("trainer_kwargs", [{"accelerator": "ipu"}, {"accelerator": "gpu", "strategy": "dp"}])
 @pytest.mark.parametrize("hook", ["transfer_batch_to_device", "on_after_batch_transfer"])
 def test_raise_exception_with_batch_transfer_hooks(monkeypatch, hook, trainer_kwargs, tmpdir):