Fix typing in `pl.overrides.distributed` (#10797)

* fix typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-11-30 22:22:05 +01:00 committed by GitHub
parent f407a00cec
commit c6b52ef3bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 15 deletions

View File

@ -70,7 +70,6 @@ module = [
"pytorch_lightning.loggers.test_tube",
"pytorch_lightning.loggers.wandb",
"pytorch_lightning.loops.epoch.training_epoch_loop",
"pytorch_lightning.overrides.distributed",
"pytorch_lightning.plugins.environments.lightning_environment",
"pytorch_lightning.plugins.environments.lsf_environment",
"pytorch_lightning.plugins.environments.slurm_environment",

View File

@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, Iterator, List, Optional
from typing import Any, cast, Iterator, List, Optional, Sized, Union
import torch
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler
@ -42,11 +43,11 @@ class LightningDistributedModule(_LightningModuleWrapperBase):
super().__init__(pl_module)
def _find_tensors(obj): # pragma: no-cover
r"""
Recursively find all tensors contained in the specified object.
"""
if isinstance(obj, torch.Tensor):
def _find_tensors(
obj: Union[Tensor, list, tuple, dict, Any]
) -> Union[List[Tensor], itertools.chain]: # pragma: no-cover
"""Recursively find all tensors contained in the specified object."""
if isinstance(obj, Tensor):
return [obj]
if isinstance(obj, (list, tuple)):
return itertools.chain(*map(_find_tensors, obj))
@ -58,27 +59,26 @@ def _find_tensors(obj): # pragma: no-cover
# In manual_optimization, we need to call reducer prepare_for_backward.
# Note: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
# `prepare_for_backward` is `DistributedDataParallel` specific.
if not isinstance(model, DistributedDataParallel):
return
if torch.is_grad_enabled() and model.require_backward_grad_sync:
model.require_forward_param_sync = True
model.require_forward_param_sync = True # type: ignore[assignment]
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if model.find_unused_parameters:
model.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
model.reducer.prepare_for_backward([])
args = list(_find_tensors(output)) if model.find_unused_parameters else []
reducer = cast(torch._C._distributed_c10d.Reducer, model.reducer)
reducer.prepare_for_backward(args)
else:
model.require_forward_param_sync = False
model.require_forward_param_sync = False # type: ignore[assignment]
class UnrepeatedDistributedSampler(DistributedSampler):
"""A fork of the pytorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
"""A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
per process to be off-by-one from each other. This makes this sampler usable for predictions (it's
deterministic and doesn't require shuffling). It is potentially unsafe to use this sampler for training,
because during training the DistributedDataParallel syncs buffers on each forward pass, so it could freeze if
@ -91,6 +91,8 @@ class UnrepeatedDistributedSampler(DistributedSampler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if not isinstance(self.dataset, Sized):
raise TypeError("The given dataset must implement the `__len__` method.")
self.num_samples = len(range(self.rank, len(self.dataset), self.num_replicas))
self.total_size = len(self.dataset)
# If any process has at least one batch, every other process needs to
@ -98,6 +100,8 @@ class UnrepeatedDistributedSampler(DistributedSampler):
assert self.num_samples >= 1 or self.total_size == 0
def __iter__(self) -> Iterator[List[int]]:
if not isinstance(self.dataset, Sized):
raise TypeError("The given dataset must implement the `__len__` method.")
if self.shuffle:
# deterministically shuffle based on epoch
g = torch.Generator()