Enable using any Sampler in distributed environment in Lite (#13646)
This commit is contained in:
parent
35841613c7
commit
9098514ea0
|
@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
|
- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
|
||||||
|
|
||||||
|
|
||||||
|
- Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))
|
||||||
|
|
||||||
|
|
||||||
-
|
-
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
|
||||||
from pytorch_lightning.accelerators.accelerator import Accelerator
|
from pytorch_lightning.accelerators.accelerator import Accelerator
|
||||||
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
|
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
|
||||||
|
@ -223,13 +223,6 @@ class LightningLite(ABC):
|
||||||
"""
|
"""
|
||||||
sampler = dataloader.sampler
|
sampler = dataloader.sampler
|
||||||
if replace_sampler and self._requires_distributed_sampler(dataloader):
|
if replace_sampler and self._requires_distributed_sampler(dataloader):
|
||||||
if not isinstance(sampler, (SequentialSampler, RandomSampler)):
|
|
||||||
raise MisconfigurationException(
|
|
||||||
"You seem to have configured a sampler in your DataLoader. This will be replaced "
|
|
||||||
" by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
|
|
||||||
" distributed training. Either remove the sampler from your DataLoader or set"
|
|
||||||
" `replace_sampler=False` if you want to use your custom sampler."
|
|
||||||
)
|
|
||||||
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
|
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
|
||||||
|
|
||||||
# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
|
# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
|
||||||
|
|
Loading…
Reference in New Issue