Enable using any Sampler in distributed environment in Lite (#13646)

This commit is contained in:
otaj 2022-07-14 14:05:25 +02:00 committed by GitHub
parent 35841613c7
commit 9098514ea0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 8 deletions

View File

@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
- Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))
-

View File

@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader, DistributedSampler
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@ -223,13 +223,6 @@ class LightningLite(ABC):
"""
sampler = dataloader.sampler
if replace_sampler and self._requires_distributed_sampler(dataloader):
if not isinstance(sampler, (SequentialSampler, RandomSampler)):
raise MisconfigurationException(
"You seem to have configured a sampler in your DataLoader. This will be replaced "
" by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
" distributed training. Either remove the sampler from your DataLoader or set"
" `replace_sampler=False` if you want to use your custom sampler."
)
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)