diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 6856c4b21c..b34d2a2f9b 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604)) +- Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646)) + + - diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 4dfcde177f..f5cebcd89a 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler +from torch.utils.data import DataLoader, DistributedSampler from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer @@ -223,13 +223,6 @@ class LightningLite(ABC): """ sampler = dataloader.sampler if replace_sampler and self._requires_distributed_sampler(dataloader): - if not isinstance(sampler, (SequentialSampler, RandomSampler)): - raise MisconfigurationException( - "You seem to have configured a sampler in your DataLoader. This will be replaced " - " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using" - " distributed training. Either remove the sampler from your DataLoader or set" - " `replace_sampler=False` if you want to use your custom sampler." - ) sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs) # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)