Enable using any Sampler in distributed environment in Lite (#13646)
This commit is contained in:
parent
35841613c7
commit
9098514ea0
|
@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
|
||||
|
||||
|
||||
- Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))
|
||||
|
||||
|
||||
-
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
||||
from pytorch_lightning.accelerators.accelerator import Accelerator
|
||||
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
|
||||
|
@ -223,13 +223,6 @@ class LightningLite(ABC):
|
|||
"""
|
||||
sampler = dataloader.sampler
|
||||
if replace_sampler and self._requires_distributed_sampler(dataloader):
|
||||
if not isinstance(sampler, (SequentialSampler, RandomSampler)):
|
||||
raise MisconfigurationException(
|
||||
"You seem to have configured a sampler in your DataLoader. This will be replaced "
|
||||
" by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
|
||||
" distributed training. Either remove the sampler from your DataLoader or set"
|
||||
" `replace_sampler=False` if you want to use your custom sampler."
|
||||
)
|
||||
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
|
||||
|
||||
# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
|
||||
|
|
Loading…
Reference in New Issue