From 94e567e6f0dbe3e06feebc8fdf4bdc807182fa96 Mon Sep 17 00:00:00 2001
From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com>
Date: Fri, 26 Aug 2022 09:28:27 -0400
Subject: [PATCH] Fix mypy errors attributed to
 `pytorch_lightning.trainer.connectors.data_connector.py` (#13806)

Co-authored-by: rohitgr7
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
---
 pyproject.toml                                  |  1 -
 src/pytorch_lightning/core/datamodule.py        |  3 +-
 src/pytorch_lightning/core/module.py            |  3 +-
 .../trainer/configuration_validator.py          |  6 ++-
 .../trainer/connectors/data_connector.py        | 50 ++++++++++---------
 src/pytorch_lightning/trainer/trainer.py        |  2 +-
 .../trainer/flags/test_overfit_batches.py       |  2 +-
 7 files changed, 37 insertions(+), 30 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 18ae22feeb..cb5274d577 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,7 +55,6 @@ module = [
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
diff --git a/src/pytorch_lightning/core/datamodule.py b/src/pytorch_lightning/core/datamodule.py
index 4edde3fe6a..e730aabd8c 100644
--- a/src/pytorch_lightning/core/datamodule.py
+++ b/src/pytorch_lightning/core/datamodule.py
@@ -18,6 +18,7 @@ from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Unio
 
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
+import pytorch_lightning as pl
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.core.saving import _load_from_checkpoint
@@ -62,7 +63,7 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
     def __init__(self) -> None:
         super().__init__()
         # Pointer to the trainer object
-        self.trainer = None
+        self.trainer: Optional["pl.Trainer"] = None
 
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index 0926cc52ec..a479beadc7 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -105,7 +105,7 @@ class LightningModule(
         self._use_amp: bool = False
 
         # the precision used
-        self.precision: int = 32
+        self.precision: Union[int, str] = 32
 
         # optionally can be set by user
         self._example_input_array = None
@@ -294,6 +294,7 @@ class LightningModule(
     def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
+            assert datahook_selector is not None
             obj = datahook_selector.get_instance(hook_name)
             if isinstance(obj, self.__class__):
                 trainer_method = self._trainer._call_lightning_module_hook
diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py
index 74c477b245..9d277f5ac4 100644
--- a/src/pytorch_lightning/trainer/configuration_validator.py
+++ b/src/pytorch_lightning/trainer/configuration_validator.py
@@ -46,7 +46,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     elif trainer.state.fn == TrainerFn.PREDICTING:
         __verify_eval_loop_configuration(trainer, model, "predict")
 
-    __verify_batch_transfer_support(trainer, model)
+    __verify_batch_transfer_support(trainer)
     _check_deprecated_callback_hooks(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
     _check_on_hpc_hooks(model)
@@ -149,10 +149,12 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning
         raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")
 
 
-def __verify_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
     """Raise Misconfiguration exception since these hooks are not supported in DP mode."""
     batch_transfer_hooks = ("transfer_batch_to_device", "on_after_batch_transfer")
     datahook_selector = trainer._data_connector._datahook_selector
+    assert datahook_selector is not None
+
     for hook in batch_transfer_hooks:
         # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
         if isinstance(trainer.strategy, DataParallelStrategy) and (
diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py
index dae4e4a352..d866584394 100644
--- a/src/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Collection, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union
 from weakref import proxy
 
 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -55,7 +55,7 @@ class DataConnector:
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")
 
-        self._datahook_selector = _DataHookSelector(None, None)
+        self._datahook_selector: Optional[_DataHookSelector] = None
 
     @property
     def _should_reload_train_dl(self) -> bool:
@@ -230,7 +230,7 @@ class DataConnector:
                 category=PossibleUserWarning,
             )
 
-    def _requires_distributed_sampler(self, dataloader) -> bool:
+    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
             self.trainer._accelerator_connector.replace_sampler_ddp
             and self.trainer._accelerator_connector.is_distributed
@@ -292,14 +292,18 @@ class DataConnector:
 
         return dataloader
 
-    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
+    def _resolve_sampler(
+        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
+    ) -> Union[Sampler, Iterable]:
         if self._requires_distributed_sampler(dataloader):
+            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
+            assert distributed_sampler_kwargs is not None
             sampler = self._get_distributed_sampler(
                 dataloader,
                 shuffle,
                 mode=mode,
                 overfit_batches=self.trainer.overfit_batches,
-                **self.trainer.distributed_sampler_kwargs,
+                **distributed_sampler_kwargs,
             )
 
             # update docs too once this is resolved
@@ -357,7 +361,7 @@ class DataConnector:
             dataloaders = self._resolve_overfit_batches(dataloaders, mode)
 
         if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
+            dataloaders = [dataloaders]  # type: ignore[assignment]
 
         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@@ -426,7 +430,7 @@ class DataConnector:
 
         return loader_num_batches, dataloaders
 
-    def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
+    def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS:
         """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.
 
         Returns:
@@ -447,10 +451,12 @@ class DataConnector:
         return dataloader
 
     @staticmethod
-    def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
+    def _resolve_overfit_batches(
+        dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], mode: RunningStage
+    ) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         all_have_sequential_sampler = True
 
-        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None:
             nonlocal all_have_sequential_sampler
             all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                 dataloader.sampler, SequentialSampler
@@ -460,19 +466,23 @@ class DataConnector:
 
         if not all_have_sequential_sampler:
             rank_zero_warn(
-                "You requested to overfit but enabled training dataloader shuffling."
+                f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
                 f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
             )
 
             def replace_sampler(dataloader: DataLoader) -> DataLoader:
-                return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)
+                return _update_dataloader(
+                    dataloader,
+                    sampler=SequentialSampler(dataloader.dataset),  # type: ignore[arg-type]
+                    mode=mode,
+                )
 
             dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler)
 
         return dataloaders
 
     @staticmethod
-    def _check_eval_shuffling(dataloader, mode):
+    def _check_eval_shuffling(dataloader: DataLoader, mode: RunningStage) -> None:
         # limit this warning only for samplers assigned automatically when shuffle is set
         if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(
@@ -506,18 +516,14 @@ class _DataLoaderSource:
 
         If the source is a module, the method with the corresponding :attr:`name` gets called.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        if not self.name:
-            return self.instance
-
-        if isinstance(self.instance, LightningModule):
+        if isinstance(self.instance, pl.LightningModule):
             return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance)
 
-        if isinstance(self.instance, LightningDataModule):
+        if isinstance(self.instance, pl.LightningDataModule):
             method = getattr(self.instance, self.name)
             return method()
 
+        assert self.instance is not None
         return self.instance
 
     def is_defined(self) -> bool:
@@ -532,9 +538,7 @@ class _DataLoaderSource:
 
         It does not check whether ``*_dataloader`` methods are actually overridden.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        return isinstance(self.instance, (LightningModule, LightningDataModule))
+        return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule))
 
 
 @dataclass
@@ -553,7 +557,7 @@ class _DataHookSelector:
 
     model: "pl.LightningModule"
     datamodule: Optional["pl.LightningDataModule"]
-    _valid_hooks: Tuple[str] = field(
+    _valid_hooks: Tuple[str, ...] = field(
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )
 
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index 8bee0ac6df..d5bcec7db8 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -2234,7 +2234,7 @@ class Trainer(
         return self.strategy.is_global_zero
 
     @property
-    def distributed_sampler_kwargs(self) -> Optional[dict]:
+    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
         if isinstance(self.strategy, ParallelStrategy):
             return self.strategy.distributed_sampler_kwargs
 
diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
index da3e154349..dc73e76cc3 100644
--- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
+++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -66,7 +66,7 @@ def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
 
-    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+    with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
         trainer.fit(model)
 
     assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
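
Illustrative sketch, not part of the patch: the recurring typing fix above declares lazily-assigned attributes as Optional[...] and narrows them with `assert ... is not None` immediately before use, so mypy accepts the access without a `# type: ignore`. A minimal example of that pattern with hypothetical names (not Lightning APIs):

    from typing import Optional

    class Connector:
        def __init__(self) -> None:
            # Not available yet; assigned later, e.g. when a trainer attaches it.
            self._selector: Optional[str] = None

        def describe(self) -> str:
            # mypy narrows Optional[str] to str after this assert,
            # so the attribute access below type-checks cleanly.
            assert self._selector is not None
            return self._selector.upper()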