Enable using any Sampler in distributed environment in Lite (#13646)

otaj 2022-07-14 14:05:25 +02:00 committed by GitHub
parent 35841613c7
commit 9098514ea0
2 changed files with 4 additions and 8 deletions

CHANGELOG.md

@@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
+- Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))
 -

src/pytorch_lightning/lite/lite.py

@@ -22,7 +22,7 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, DistributedSampler
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -223,13 +223,6 @@ class LightningLite(ABC):
         """
         sampler = dataloader.sampler
         if replace_sampler and self._requires_distributed_sampler(dataloader):
-            if not isinstance(sampler, (SequentialSampler, RandomSampler)):
-                raise MisconfigurationException(
-                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
-                    " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
-                    " distributed training. Either remove the sampler from your DataLoader or set"
-                    " `replace_sampler=False` if you want to use your custom sampler."
-                )
             sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
         # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
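In effect, `setup_dataloaders` no longer rejects custom samplers when a distributed strategy is active. A minimal usage sketch of the new behavior follows (assuming the `LightningLite` API of this release; `ReverseSampler` and the toy dataset are illustrative and not part of the commit):

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

from pytorch_lightning.lite import LightningLite


class ReverseSampler(Sampler):
    """Hypothetical custom sampler: yields dataset indices in reverse order."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)


class Lite(LightningLite):
    def run(self):
        dataset = TensorDataset(torch.arange(64.0).unsqueeze(1))
        dataloader = DataLoader(dataset, batch_size=8, sampler=ReverseSampler(dataset))

        # Before this commit, any sampler other than SequentialSampler or
        # RandomSampler raised a MisconfigurationException here under a
        # distributed strategy; now the sampler is swapped for a distributed
        # one via `_get_distributed_sampler` instead.
        dataloader = self.setup_dataloaders(dataloader)

        # To keep the custom sampler, opt out of the replacement (sharding
        # the data across ranks is then the user's responsibility):
        # dataloader = self.setup_dataloaders(dataloader, replace_sampler=False)

        for (batch,) in dataloader:
            ...


if __name__ == "__main__":
    Lite(strategy="ddp", devices=2).run()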