Fix mypy errors attributed to `pytorch_lightning.trainer.connectors.data_connector.py` (#13806)

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>

parent e2221a0b3e
commit 94e567e6f0

@@ -55,7 +55,6 @@ module = [
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",

@@ -18,6 +18,7 @@ from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
 from torch.utils.data import DataLoader, Dataset, IterableDataset

+import pytorch_lightning as pl
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.core.saving import _load_from_checkpoint

@@ -62,7 +63,7 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
     def __init__(self) -> None:
         super().__init__()
         # Pointer to the trainer object
-        self.trainer = None
+        self.trainer: Optional["pl.Trainer"] = None

     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:

@@ -105,7 +105,7 @@ class LightningModule(
         self._use_amp: bool = False

         # the precision used
-        self.precision: int = 32
+        self.precision: Union[int, str] = 32

         # optionally can be set by user
         self._example_input_array = None

@@ -294,6 +294,7 @@ class LightningModule(
     def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
+            assert datahook_selector is not None
             obj = datahook_selector.get_instance(hook_name)
             if isinstance(obj, self.__class__):
                 trainer_method = self._trainer._call_lightning_module_hook

@@ -46,7 +46,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     elif trainer.state.fn == TrainerFn.PREDICTING:
         __verify_eval_loop_configuration(trainer, model, "predict")

-    __verify_batch_transfer_support(trainer, model)
+    __verify_batch_transfer_support(trainer)
     _check_deprecated_callback_hooks(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
     _check_on_hpc_hooks(model)

@@ -149,10 +149,12 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning
         raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")


-def __verify_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
     """Raise Misconfiguration exception since these hooks are not supported in DP mode."""
     batch_transfer_hooks = ("transfer_batch_to_device", "on_after_batch_transfer")
+    datahook_selector = trainer._data_connector._datahook_selector
+    assert datahook_selector is not None

     for hook in batch_transfer_hooks:
         # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
         if isinstance(trainer.strategy, DataParallelStrategy) and (

@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Collection, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union
 from weakref import proxy

 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler

@@ -55,7 +55,7 @@ class DataConnector:
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")

-        self._datahook_selector = _DataHookSelector(None, None)
+        self._datahook_selector: Optional[_DataHookSelector] = None

     @property
     def _should_reload_train_dl(self) -> bool:

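A note on the pattern above: `_datahook_selector` is now annotated `Optional[_DataHookSelector]` and starts out as `None`, which is why the earlier hunks add `assert datahook_selector is not None` before the selector is used. Below is a minimal, self-contained sketch of how mypy narrows an Optional attribute through such an assert; the `Connector`/`Selector` names are illustrative, not Lightning's:

from typing import Optional


class Selector:
    def get_instance(self, name: str) -> str:
        return name


class Connector:
    def __init__(self) -> None:
        # Declared Optional and populated later, mirroring `_datahook_selector` above.
        self.selector: Optional[Selector] = None

    def call_hook(self, name: str) -> str:
        selector = self.selector
        # Without this assert mypy reports: Item "None" of "Optional[Selector]" has no
        # attribute "get_instance"; the assert narrows the local to a plain `Selector`.
        assert selector is not None
        return selector.get_instance(name)
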
@@ -230,7 +230,7 @@ class DataConnector:
                 category=PossibleUserWarning,
             )

-    def _requires_distributed_sampler(self, dataloader) -> bool:
+    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
             self.trainer._accelerator_connector.replace_sampler_ddp
             and self.trainer._accelerator_connector.is_distributed

@@ -292,14 +292,18 @@ class DataConnector:

         return dataloader

-    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
+    def _resolve_sampler(
+        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
+    ) -> Union[Sampler, Iterable]:
         if self._requires_distributed_sampler(dataloader):
+            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
+            assert distributed_sampler_kwargs is not None
             sampler = self._get_distributed_sampler(
                 dataloader,
                 shuffle,
                 mode=mode,
                 overfit_batches=self.trainer.overfit_batches,
-                **self.trainer.distributed_sampler_kwargs,
+                **distributed_sampler_kwargs,
             )

             # update docs too once this is resolved

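The `_resolve_sampler` hunk pulls `self.trainer.distributed_sampler_kwargs` into a local and asserts it before unpacking, since the property is Optional and a possible `None` cannot be splatted with `**`. A small sketch of that shape, using an illustrative `FakeTrainer` rather than the real trainer:

from typing import Any, Dict, Optional


class FakeTrainer:
    """Illustrative stand-in; only the Optional property shape matters here."""

    @property
    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
        return {"num_replicas": 2, "rank": 0}


def build_sampler_kwargs(trainer: FakeTrainer, shuffle: bool) -> Dict[str, Any]:
    # Splatting the Optional property directly (**trainer.distributed_sampler_kwargs)
    # fails type checking because the value may be None; bind it to a local and assert
    # so the unpacked dict is known to be non-Optional.
    distributed_sampler_kwargs = trainer.distributed_sampler_kwargs
    assert distributed_sampler_kwargs is not None
    return {"shuffle": shuffle, **distributed_sampler_kwargs}
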
@@ -357,7 +361,7 @@ class DataConnector:
             dataloaders = self._resolve_overfit_batches(dataloaders, mode)

         if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
+            dataloaders = [dataloaders]  # type: ignore[assignment]

         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")

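The `# type: ignore[assignment]` above is the narrow form of suppression: the variable keeps its broad dataloader-union annotation while the runtime value is coerced to a list, and mypy flags that re-assignment. A rough sketch of the same situation with a simplified union (not Lightning's actual aliases):

from typing import Sequence, Union

from torch.utils.data import DataLoader


def ensure_list(dataloaders: Union[DataLoader, Sequence[DataLoader]]) -> Union[DataLoader, Sequence[DataLoader]]:
    if not isinstance(dataloaders, list):
        # The right-hand side is a `list` whose element type is the whole union, which mypy
        # will not assign back to the annotated variable, hence the targeted error code.
        dataloaders = [dataloaders]  # type: ignore[assignment]
    return dataloaders
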
@@ -426,7 +430,7 @@ class DataConnector:

         return loader_num_batches, dataloaders

-    def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
+    def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS:
         """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.

         Returns:

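`_request_dataloader` now returns the `TRAIN_DATALOADERS` alias from `pytorch_lightning.utilities.types` rather than spelling out a union inline, so every shape a user's `*_dataloader` hook may return is described in one place. The aliases below are simplified stand-ins to show the idea, not Lightning's exact definitions:

from typing import Dict, Sequence, Union

from torch.utils.data import DataLoader

# Simplified stand-ins; the real TRAIN_DATALOADERS / EVAL_DATALOADERS aliases cover
# additional nested combinations of dataloaders.
TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Dict[str, DataLoader]]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]


def request_dataloader(hook_output: TRAIN_DATALOADERS) -> TRAIN_DATALOADERS:
    # One annotation documents every shape the user's hook is allowed to return.
    return hook_output
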
@@ -447,10 +451,12 @@ class DataConnector:
         return dataloader

     @staticmethod
-    def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
+    def _resolve_overfit_batches(
+        dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], mode: RunningStage
+    ) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         all_have_sequential_sampler = True

-        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None:
             nonlocal all_have_sequential_sampler
             all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                 dataloader.sampler, SequentialSampler

@@ -460,19 +466,23 @@ class DataConnector:

         if not all_have_sequential_sampler:
             rank_zero_warn(
-                "You requested to overfit but enabled training dataloader shuffling."
+                f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
                 f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
             )

         def replace_sampler(dataloader: DataLoader) -> DataLoader:
-            return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)
+            return _update_dataloader(
+                dataloader,
+                sampler=SequentialSampler(dataloader.dataset),  # type: ignore[arg-type]
+                mode=mode,
+            )

         dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler)

         return dataloaders

     @staticmethod
-    def _check_eval_shuffling(dataloader, mode):
+    def _check_eval_shuffling(dataloader: DataLoader, mode: RunningStage) -> None:
         # limit this warning only for samplers assigned automatically when shuffle is set
         if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(

@@ -506,18 +516,14 @@ class _DataLoaderSource:

         If the source is a module, the method with the corresponding :attr:`name` gets called.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        if not self.name:
-            return self.instance
-
-        if isinstance(self.instance, LightningModule):
+        if isinstance(self.instance, pl.LightningModule):
             return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance)

-        if isinstance(self.instance, LightningDataModule):
+        if isinstance(self.instance, pl.LightningDataModule):
             method = getattr(self.instance, self.name)
             return method()

+        assert self.instance is not None
         return self.instance

     def is_defined(self) -> bool:

@@ -532,9 +538,7 @@ class _DataLoaderSource:

         It does not check whether ``*_dataloader`` methods are actually overridden.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        return isinstance(self.instance, (LightningModule, LightningDataModule))
+        return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule))


 @dataclass

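The two `_DataLoaderSource` hunks drop the function-local `from pytorch_lightning import ...` imports (kept there to dodge a circular import) in favour of the module-level `import pytorch_lightning as pl` alias that the new code relies on. The alias works because `pl.LightningModule` is only looked up when the function runs, after the package has finished importing. A minimal sketch of the pattern, assuming pytorch_lightning is installed:

import pytorch_lightning as pl  # the attribute lookups below happen at call time, not import time


def source_kind(instance: object) -> str:
    # Inside the package, a module-level `from pytorch_lightning import LightningModule`
    # would need the name to exist while the package is still initialising; the alias
    # form defers that lookup to the isinstance call, so no local import is needed.
    if isinstance(instance, pl.LightningModule):
        return "lightning module"
    if isinstance(instance, pl.LightningDataModule):
        return "lightning datamodule"
    return "other"
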
@@ -553,7 +557,7 @@ class _DataHookSelector:

     model: "pl.LightningModule"
     datamodule: Optional["pl.LightningDataModule"]
-    _valid_hooks: Tuple[str] = field(
+    _valid_hooks: Tuple[str, ...] = field(
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )

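The `_valid_hooks` fix is a common typing slip: `Tuple[str]` describes a tuple of exactly one string, so mypy flags the three-element default, while `Tuple[str, ...]` is the variable-length form. A tiny sketch with an illustrative dataclass:

from dataclasses import dataclass, field
from typing import Tuple


@dataclass
class HookSelector:
    # `Tuple[str]` would only accept a 1-tuple; the `...` form accepts any length.
    valid_hooks: Tuple[str, ...] = field(
        default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
    )
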
@@ -2234,7 +2234,7 @@ class Trainer(
         return self.strategy.is_global_zero

     @property
-    def distributed_sampler_kwargs(self) -> Optional[dict]:
+    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
         if isinstance(self.strategy, ParallelStrategy):
             return self.strategy.distributed_sampler_kwargs

@@ -66,7 +66,7 @@ def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)

-    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+    with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
         trainer.fit(model)

     assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)