Fix mypy errors attributed to `pytorch_lightning.trainer.connectors.data_connector.py` (#13806)

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
Justin Goheen 2022-08-26 09:28:27 -04:00 committed by GitHub
parent e2221a0b3e
commit 94e567e6f0
7 changed files with 37 additions and 30 deletions
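
Most of the hunks below apply the same mypy-narrowing pattern: an attribute that is only populated after construction is annotated `Optional[...]`, and each use site binds it to a local name and asserts it is not `None` before using it. A minimal, self-contained sketch of that pattern (the `Selector`/`Connector` names are illustrative, not Lightning classes):

from typing import Optional


class Selector:
    def get_instance(self, hook_name: str) -> str:
        return hook_name


class Connector:
    def __init__(self) -> None:
        # Populated later (e.g. when attached), so it starts out as None.
        self.selector: Optional[Selector] = None

    def call_hook(self, hook_name: str) -> str:
        selector = self.selector
        # Without this narrowing, mypy reports something like:
        # Item "None" of "Optional[Selector]" has no attribute "get_instance"
        assert selector is not None
        return selector.get_instance(hook_name)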

View File

@@ -55,7 +55,6 @@ module = [
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",

View File

@@ -18,6 +18,7 @@ from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
 from torch.utils.data import DataLoader, Dataset, IterableDataset

+import pytorch_lightning as pl
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.core.saving import _load_from_checkpoint
@@ -62,7 +63,7 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
     def __init__(self) -> None:
         super().__init__()
         # Pointer to the trainer object
-        self.trainer = None
+        self.trainer: Optional["pl.Trainer"] = None

     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:

View File

@@ -105,7 +105,7 @@ class LightningModule(
         self._use_amp: bool = False

         # the precision used
-        self.precision: int = 32
+        self.precision: Union[int, str] = 32

         # optionally can be set by user
         self._example_input_array = None
@@ -294,6 +294,7 @@ class LightningModule(
     def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
+            assert datahook_selector is not None
             obj = datahook_selector.get_instance(hook_name)
             if isinstance(obj, self.__class__):
                 trainer_method = self._trainer._call_lightning_module_hook

View File

@@ -46,7 +46,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     elif trainer.state.fn == TrainerFn.PREDICTING:
         __verify_eval_loop_configuration(trainer, model, "predict")

-    __verify_batch_transfer_support(trainer, model)
+    __verify_batch_transfer_support(trainer)
     _check_deprecated_callback_hooks(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
     _check_on_hpc_hooks(model)
@@ -149,10 +149,12 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning
         raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")


-def __verify_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
     """Raise Misconfiguration exception since these hooks are not supported in DP mode."""
     batch_transfer_hooks = ("transfer_batch_to_device", "on_after_batch_transfer")
+    datahook_selector = trainer._data_connector._datahook_selector
+    assert datahook_selector is not None
     for hook in batch_transfer_hooks:
         # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
         if isinstance(trainer.strategy, DataParallelStrategy) and (

View File

@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Collection, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union
 from weakref import proxy

 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -55,7 +55,7 @@ class DataConnector:
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")
-        self._datahook_selector = _DataHookSelector(None, None)
+        self._datahook_selector: Optional[_DataHookSelector] = None

     @property
     def _should_reload_train_dl(self) -> bool:
@@ -230,7 +230,7 @@ class DataConnector:
                 category=PossibleUserWarning,
             )

-    def _requires_distributed_sampler(self, dataloader) -> bool:
+    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
             self.trainer._accelerator_connector.replace_sampler_ddp
             and self.trainer._accelerator_connector.is_distributed
@@ -292,14 +292,18 @@ class DataConnector:
         return dataloader

-    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
+    def _resolve_sampler(
+        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
+    ) -> Union[Sampler, Iterable]:
         if self._requires_distributed_sampler(dataloader):
+            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
+            assert distributed_sampler_kwargs is not None
             sampler = self._get_distributed_sampler(
                 dataloader,
                 shuffle,
                 mode=mode,
                 overfit_batches=self.trainer.overfit_batches,
-                **self.trainer.distributed_sampler_kwargs,
+                **distributed_sampler_kwargs,
             )

             # update docs too once this is resolved
@@ -357,7 +361,7 @@ class DataConnector:
             dataloaders = self._resolve_overfit_batches(dataloaders, mode)

         if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
+            dataloaders = [dataloaders]  # type: ignore[assignment]

         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@@ -426,7 +430,7 @@ class DataConnector:
         return loader_num_batches, dataloaders

-    def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
+    def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS:
         """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.

         Returns:
@@ -447,10 +451,12 @@ class DataConnector:
         return dataloader

     @staticmethod
-    def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
+    def _resolve_overfit_batches(
+        dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], mode: RunningStage
+    ) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         all_have_sequential_sampler = True

-        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None:
             nonlocal all_have_sequential_sampler
             all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                 dataloader.sampler, SequentialSampler
@@ -460,19 +466,23 @@ class DataConnector:
         if not all_have_sequential_sampler:
             rank_zero_warn(
-                "You requested to overfit but enabled training dataloader shuffling."
+                f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
                 f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
             )

         def replace_sampler(dataloader: DataLoader) -> DataLoader:
-            return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)
+            return _update_dataloader(
+                dataloader,
+                sampler=SequentialSampler(dataloader.dataset),  # type: ignore[arg-type]
+                mode=mode,
+            )

         dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler)

         return dataloaders

     @staticmethod
-    def _check_eval_shuffling(dataloader, mode):
+    def _check_eval_shuffling(dataloader: DataLoader, mode: RunningStage) -> None:
         # limit this warning only for samplers assigned automatically when shuffle is set
         if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(
@@ -506,18 +516,14 @@ class _DataLoaderSource:
         If the source is a module, the method with the corresponding :attr:`name` gets called.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        if not self.name:
-            return self.instance
-
-        if isinstance(self.instance, LightningModule):
+        if isinstance(self.instance, pl.LightningModule):
             return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance)

-        if isinstance(self.instance, LightningDataModule):
+        if isinstance(self.instance, pl.LightningDataModule):
             method = getattr(self.instance, self.name)
             return method()

+        assert self.instance is not None
         return self.instance

     def is_defined(self) -> bool:
@@ -532,9 +538,7 @@
         It does not check whether ``*_dataloader`` methods are actually overridden.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        return isinstance(self.instance, (LightningModule, LightningDataModule))
+        return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule))


 @dataclass
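
The two `_DataLoaderSource` hunks above also swap the function-local `from pytorch_lightning import ...` imports (a circular-import workaround) for the `pl` alias that the module already imports at the top level; the isinstance checks stay identical at runtime while giving the type checker a resolvable reference. A small sketch of the idiom, with the helper name `is_lightning_container` being illustrative:

import pytorch_lightning as pl


def is_lightning_container(obj: object) -> bool:
    # Same runtime behaviour as importing the classes locally, but typed via the package alias.
    return isinstance(obj, (pl.LightningModule, pl.LightningDataModule))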
@@ -553,7 +557,7 @@ class _DataHookSelector:
     model: "pl.LightningModule"
     datamodule: Optional["pl.LightningDataModule"]
-    _valid_hooks: Tuple[str] = field(
+    _valid_hooks: Tuple[str, ...] = field(
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )
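
The `_valid_hooks` change above fixes a common tuple-typing mistake: `Tuple[str]` means a tuple of exactly one string, so the three-element default does not type-check, while `Tuple[str, ...]` is the variable-length form. A standalone illustration (variable names are made up for the example):

from typing import Tuple

# Exactly one element: a three-element value here would be rejected by mypy.
one_hook: Tuple[str] = ("on_before_batch_transfer",)

# Variable length: any number of str elements is accepted.
valid_hooks: Tuple[str, ...] = (
    "on_before_batch_transfer",
    "transfer_batch_to_device",
    "on_after_batch_transfer",
)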

View File

@@ -2234,7 +2234,7 @@ class Trainer(
         return self.strategy.is_global_zero

     @property
-    def distributed_sampler_kwargs(self) -> Optional[dict]:
+    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
         if isinstance(self.strategy, ParallelStrategy):
             return self.strategy.distributed_sampler_kwargs
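
This return annotation pairs with the `_resolve_sampler` hunk earlier in the diff: since the property can return `None`, the caller binds the result, asserts it is not `None`, and only then unpacks it with `**` (mypy rejects unpacking an `Optional` mapping directly). A reduced sketch of that call pattern, with made-up function names:

from typing import Any, Dict, Optional


def distributed_sampler_kwargs(is_distributed: bool) -> Optional[Dict[str, Any]]:
    return {"num_replicas": 2, "rank": 0} if is_distributed else None


def build_sampler_config(is_distributed: bool) -> Dict[str, Any]:
    kwargs = distributed_sampler_kwargs(is_distributed)
    # The surrounding logic only reaches this point when kwargs are available,
    # but mypy cannot know that, so narrow explicitly before the ** unpacking.
    assert kwargs is not None
    return dict(shuffle=False, **kwargs)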

View File

@@ -66,7 +66,7 @@ def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)

-    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+    with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
         trainer.fit(model)

     assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)