From c6b52ef3bdbf35c815351e704001e296e5491a30 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 22:22:05 +0100
Subject: [PATCH] Fix typing in `pl.overrides.distributed` (#10797)

* fix typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 pyproject.toml                              |  1 -
 pytorch_lightning/overrides/distributed.py | 32 ++++++++++++----------
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index b491e0d691..b44152ac9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,7 +70,6 @@ module = [
     "pytorch_lightning.loggers.test_tube",
     "pytorch_lightning.loggers.wandb",
     "pytorch_lightning.loops.epoch.training_epoch_loop",
-    "pytorch_lightning.overrides.distributed",
     "pytorch_lightning.plugins.environments.lightning_environment",
     "pytorch_lightning.plugins.environments.lsf_environment",
     "pytorch_lightning.plugins.environments.slurm_environment",
diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py
index 0cf392dd44..f7c2a71b49 100644
--- a/pytorch_lightning/overrides/distributed.py
+++ b/pytorch_lightning/overrides/distributed.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, Iterator, List, Optional
+from typing import Any, cast, Iterator, List, Optional, Sized, Union
 
 import torch
+from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler
 
@@ -42,11 +43,11 @@ class LightningDistributedModule(_LightningModuleWrapperBase):
         super().__init__(pl_module)
 
 
-def _find_tensors(obj):  # pragma: no-cover
-    r"""
-    Recursively find all tensors contained in the specified object.
-    """
-    if isinstance(obj, torch.Tensor):
+def _find_tensors(
+    obj: Union[Tensor, list, tuple, dict, Any]
+) -> Union[List[Tensor], itertools.chain]:  # pragma: no-cover
+    """Recursively find all tensors contained in the specified object."""
+    if isinstance(obj, Tensor):
         return [obj]
     if isinstance(obj, (list, tuple)):
         return itertools.chain(*map(_find_tensors, obj))
@@ -58,27 +59,26 @@
 # In manual_optimization, we need to call reducer prepare_for_backward.
 # Note: Keep track of Pytorch DDP and update if there is a change
 # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
-def prepare_for_backward(model: DistributedDataParallel, output: Any):
+def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
     # `prepare_for_backward` is `DistributedDataParallel` specific.
     if not isinstance(model, DistributedDataParallel):
         return
     if torch.is_grad_enabled() and model.require_backward_grad_sync:
-        model.require_forward_param_sync = True
+        model.require_forward_param_sync = True  # type: ignore[assignment]
         # We'll return the output object verbatim since it is a freeform
         # object. We need to find any tensors in this object, though,
         # because we need to figure out which parameters were used during
         # this forward pass, to ensure we short circuit reduction for any
         # unused parameters. Only if `find_unused_parameters` is set.
-        if model.find_unused_parameters:
-            model.reducer.prepare_for_backward(list(_find_tensors(output)))
-        else:
-            model.reducer.prepare_for_backward([])
+        args = list(_find_tensors(output)) if model.find_unused_parameters else []
+        reducer = cast(torch._C._distributed_c10d.Reducer, model.reducer)
+        reducer.prepare_for_backward(args)
     else:
-        model.require_forward_param_sync = False
+        model.require_forward_param_sync = False  # type: ignore[assignment]
 
 
 class UnrepeatedDistributedSampler(DistributedSampler):
-    """A fork of the pytorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
+    """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
     per process to be off-by-one from each other. This makes this sampler usable for predictions (it's deterministic
     and doesn't require shuffling). It is potentially unsafe to use this sampler for training, because during training
     the DistributedDataParallel syncs buffers on each forward pass, so it could freeze if
@@ -91,6 +91,8 @@ class UnrepeatedDistributedSampler(DistributedSampler):
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
+        if not isinstance(self.dataset, Sized):
+            raise TypeError("The given dataset must implement the `__len__` method.")
         self.num_samples = len(range(self.rank, len(self.dataset), self.num_replicas))
         self.total_size = len(self.dataset)
         # If any process has at least one batch, every other process needs to
@@ -98,6 +100,8 @@ class UnrepeatedDistributedSampler(DistributedSampler):
         assert self.num_samples >= 1 or self.total_size == 0
 
     def __iter__(self) -> Iterator[List[int]]:
+        if not isinstance(self.dataset, Sized):
+            raise TypeError("The given dataset must implement the `__len__` method.")
         if self.shuffle:
             # deterministically shuffle based on epoch
             g = torch.Generator()
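
A standalone sketch (not part of the patch) of the recursion that the new `_find_tensors` signature describes: a bare tensor yields a `List[Tensor]`, while containers yield a lazy `itertools.chain`, which is why the return annotation is `Union[List[Tensor], itertools.chain]`. The dict branch is not visible in the hunk's context and is assumed here to follow the same pattern over `obj.values()`.

import itertools
from typing import Any, List, Union

import torch
from torch import Tensor


def find_tensors(obj: Union[Tensor, list, tuple, dict, Any]) -> Union[List[Tensor], itertools.chain]:
    # Mirrors the patched `_find_tensors`: collect every tensor nested inside `obj`.
    if isinstance(obj, Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain(*map(find_tensors, obj))
    if isinstance(obj, dict):
        # assumed dict handling, recursing over the values only
        return itertools.chain(*map(find_tensors, obj.values()))
    return []


output = {"loss": torch.zeros(1), "logs": [torch.ones(2), "not a tensor"]}
print(list(find_tensors(output)))  # [tensor([0.]), tensor([1., 1.])]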