[bugfix] Add support for CombinedLoader in validation with ddp (#7102)
* add test
* add changelog
* resolve flake8
* remove print
parent 67528c4665
commit 9beec26c3e
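For context, a minimal sketch of the pattern this commit enables (assumed usage, not part of the diff): a LightningModule returning a `CombinedLoader` from `val_dataloader`, so that under ddp each nested `DataLoader` can have a `DistributedSampler` injected. `MyModel` and the tensor shapes are illustrative.

```python
# Minimal sketch (assumed usage, not from this commit): returning a
# CombinedLoader for validation. With this fix, Lightning recurses into the
# nested loaders and injects a DistributedSampler into each DataLoader.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule
from pytorch_lightning.trainer.supporters import CombinedLoader


class MyModel(LightningModule):  # hypothetical module, training logic omitted

    def val_dataloader(self):
        loaders = {
            "a": DataLoader(TensorDataset(torch.randn(64, 3)), batch_size=4),
            "b": DataLoader(TensorDataset(torch.randn(32, 3)), batch_size=4),
        }
        # 'max_size_cycle' cycles shorter loaders until the longest is exhausted
        return CombinedLoader(loaders, mode="max_size_cycle")
```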
CHANGELOG.md
@@ -304,6 +304,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed metric objects passed directly to `self.log` not being reset correctly ([#7055](https://github.com/PyTorchLightning/pytorch-lightning/pull/7055))

+- Fixed `CombinedLoader` in distributed settings for validation / testing ([#7102](https://github.com/PyTorchLightning/pytorch-lightning/pull/7102))

 ## [1.2.7] - 2021-04-06

 ### Fixed
pytorch_lightning/trainer/data_loading.py
@@ -108,12 +108,15 @@ class TrainerDataLoadingMixin(ABC):
         dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)

     def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
         # don't do anything if it's not a dataloader
         is_dataloader = isinstance(dataloader, DataLoader)
         # don't manipulate iterable datasets
         is_iterable_ds = has_iterable_dataset(dataloader)

+        if isinstance(dataloader, CombinedLoader):
+            dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
+            return dataloader

         if not is_dataloader or is_iterable_ds:
             return dataloader
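The recursion above leans on `apply_to_collection`, which walks nested dicts / lists / tuples, applies a function to every element of the requested dtype, and preserves the collection's shape. A standalone sketch of that traversal (the `tag` helper is a hypothetical stand-in for `self.auto_add_sampler`):

```python
# Sketch of the traversal used by auto_add_sampler: apply_to_collection
# visits every DataLoader in a nested structure and keeps the structure intact.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection


def tag(dl: DataLoader) -> str:
    # hypothetical stand-in for self.auto_add_sampler
    return f"wrapped({len(dl.dataset)})"


loaders = {
    "a": DataLoader(TensorDataset(torch.randn(4, 1))),
    "b": [DataLoader(TensorDataset(torch.randn(8, 1)))],
}
print(apply_to_collection(loaders, DataLoader, tag))
# expected: {'a': 'wrapped(4)', 'b': ['wrapped(8)']}
```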
pytorch_lightning/trainer/supporters.py
@@ -19,6 +19,8 @@ from typing import Any, Callable, Optional, Union
 import torch
 from torch import Tensor
 from torch.utils.data import Dataset
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import IterableDataset

 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -352,7 +354,7 @@ class CombinedLoader(object):
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
         """Return a collection of samplers extracted from the loaders."""
-        return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, wrong_dtype=(Sequence, Mapping))
+        return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, 'sampler', None)

     def _wrap_loaders_max_size_cycle(self) -> Any:
         """
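The old implementation applied `getattr` to anything `Iterable` (filtering out `Sequence` / `Mapping` via `wrong_dtype`); the new one targets `DataLoader` and `IterableDataset` nodes directly, so only actual loaders contribute samplers. A sketch of what the property yields for a nested collection (assumed output, reprs abbreviated):

```python
# Sketch (assumed behavior after this change): CombinedLoader.sampler mirrors
# the structure of `loaders`, with one sampler entry per DataLoader.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.trainer.supporters import CombinedLoader

combined = CombinedLoader({
    "a": DataLoader(TensorDataset(torch.randn(4, 1))),
    "b": [DataLoader(TensorDataset(torch.randn(4, 1)))],
})
print(combined.sampler)
# expected shape: {'a': <SequentialSampler ...>, 'b': [<SequentialSampler ...>]}
```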
tests/trainer/test_supporters.py
@@ -11,12 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from collections import Sequence
+from unittest import mock

 import pytest
 import torch
-from torch.utils.data import TensorDataset
+from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data.dataset import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import Sampler

+from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.supporters import (
     _nested_calc_num_data,
     CombinedDataset,
@@ -25,6 +31,7 @@ from pytorch_lightning.trainer.supporters import (
     CycleIterator,
     TensorRunningAccum,
 )
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -237,3 +244,46 @@ def test_nested_calc_num_data(input_data, compute_func, expected_length):
     calculated_length = _nested_calc_num_data(input_data, compute_func)

     assert calculated_length == expected_length
+
+
+@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
+@mock.patch('torch.cuda.device_count', return_value=2)
+@mock.patch('torch.cuda.is_available', return_value=True)
+def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
+    """
+    This test makes sure a distributed sampler is properly injected into the
+    dataloaders when using a CombinedLoader.
+    """
+
+    class CustomDataset(Dataset):
+
+        def __init__(self, data):
+            self.data = data
+
+        def __len__(self):
+            return len(self.data)
+
+        def __getitem__(self, index):
+            return self.data[index]
+
+    dataloader = CombinedLoader({
+        "a": DataLoader(CustomDataset(range(10))),
+        "b": {
+            "c": DataLoader(CustomDataset(range(10))),
+            "d": DataLoader(CustomDataset(range(10)))
+        },
+        "e": [DataLoader(CustomDataset(range(10))),
+              DataLoader(CustomDataset(range(10)))]
+    })
+
+    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
+    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)
+
+    _count = 0
+
+    def _assert_distributed_sampler(v):
+        nonlocal _count
+        _count += 1
+        assert isinstance(v, DistributedSampler)
+
+    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
+    assert _count == 5
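For reference, the nested collection in the test holds exactly five `DataLoader`s (one under "a", two under "b", two under "e"), so once `auto_add_sampler` has replaced the samplers, `apply_to_collection` over `dataloader.sampler` should visit five `DistributedSampler` instances; that is what the final `_count == 5` assertion verifies.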