From 1c2ecbf70c38912124a384f273090145c1328353 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Thu, 8 Apr 2021 15:27:48 +0100
Subject: [PATCH] TPUSpawn + IterableDataset error message (#6875)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 .../plugins/training_type/tpu_spawn.py | 46 +++++++++++-
 tests/plugins/test_tpu_spawn.py        | 74 +++++++++++++++++++
 2 files changed, 117 insertions(+), 3 deletions(-)
 create mode 100644 tests/plugins/test_tpu_spawn.py

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 6806893512..d546067e88 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -15,15 +15,17 @@ import io
 import os
 import re
 import time
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
 
 import torch
 import torch.multiprocessing as mp
 
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
+from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
@@ -40,6 +42,11 @@ if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf
 
 
+if TYPE_CHECKING:
+    from torch.nn import Module
+    from torch.utils.data import DataLoader
+
+
 class TPUSpawnPlugin(DDPSpawnPlugin):
 
     def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
@@ -47,7 +54,39 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.tpu_local_core_rank = 0
         self.start_method = None
 
-    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
+    @staticmethod
+    def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']):
+        if not isinstance(dataloaders, list):
+            dataloaders = [dataloaders]
+
+        for dataloader in dataloaders:
+            if not has_len(dataloader):
+                raise MisconfigurationException(
+                    "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
+                    " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
+                )
+
+    @staticmethod
+    def _validate_patched_dataloaders(model: 'Module') -> None:
+        """Validate and fail fast if the dataloaders were passed directly to fit.
+        """
+        if hasattr(model, 'train_dataloader') and isinstance(model.train_dataloader, _PatchDataLoader):
+            TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)
+
+        if hasattr(model, 'val_dataloader') and isinstance(model.val_dataloader, _PatchDataLoader):
+            TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)
+
+        if hasattr(model, 'test_dataloader') and isinstance(model.test_dataloader, _PatchDataLoader):
+            TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)
+
+        if hasattr(model, 'predict_dataloader') and isinstance(model.predict_dataloader, _PatchDataLoader):
+            TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)
+
+    def connect(self, model: 'Module') -> None:
+        TPUSpawnPlugin._validate_patched_dataloaders(model)
+        return super().connect(model)
+
+    def setup(self, model: 'Module') -> 'Module':
         self.create_mp_queue()
         return self.model
 
@@ -64,7 +103,8 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     def is_distributed(self):
         return self.world_size != 1
 
-    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
+    def process_dataloader(self, dataloader: 'DataLoader') -> MpDeviceLoader:
+        TPUSpawnPlugin._validate_dataloader(dataloader)
         device = xm.xla_device()
         dataloader = MpDeviceLoader(dataloader, device)
         return dataloader
diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py
new file mode 100644
index 0000000000..bb587827c3
--- /dev/null
+++ b/tests/plugins/test_tpu_spawn.py
@@ -0,0 +1,74 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import MagicMock
+
+import pytest
+from torch.utils.data import DataLoader
+
+from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
+from pytorch_lightning.trainer.connectors.data_connector import DataConnector
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringModel, RandomDataset
+from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader
+
+
+class BoringModelNoDataloaders(BoringModel):
+
+    def train_dataloader(self):
+        raise NotImplementedError
+
+    def val_dataloader(self):
+        raise NotImplementedError
+
+    def test_dataloader(self):
+        raise NotImplementedError
+
+    def predict_dataloader(self):
+        raise NotImplementedError
+
+
+_loader = DataLoader(RandomDataset(32, 64))
+_loader_no_len = CustomNotImplementedErrorDataloader(_loader)
+
+
+@pytest.mark.parametrize(
+    "train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders", [
+        (_loader_no_len, None, None, None),
+        (None, _loader_no_len, None, None),
+        (None, None, _loader_no_len, None),
+        (None, None, None, _loader_no_len),
+        (None, [_loader, _loader_no_len], None, None),
+    ]
+)
+def test_error_patched_iterable_dataloaders(
+    tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders
+):
+    model = BoringModelNoDataloaders()
+    connector = DataConnector(MagicMock())
+
+    connector.attach_dataloaders(
+        model,
+        train_dataloader=train_dataloader,
+        val_dataloaders=val_dataloaders,
+        test_dataloaders=test_dataloaders,
+        predict_dataloaders=predict_dataloaders,
+    )
+
+    with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
+        TPUSpawnPlugin(MagicMock()).connect(model)
+
+
+def test_error_process_iterable_dataloader(tmpdir):
+    with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
+        TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len)
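For context on the check introduced above: `has_len` reports False when `len(dataloader)` raises a TypeError, which is exactly what happens for a DataLoader wrapping an IterableDataset that defines no `__len__`, so `_validate_dataloader` turns that situation into an early MisconfigurationException. A minimal sketch of the failing case, assuming only plain torch (the `StreamingDataset` class below is a hypothetical stand-in for the `CustomNotImplementedErrorDataloader` helper used in the tests, not part of the patch):

    import torch
    from torch.utils.data import DataLoader, IterableDataset


    class StreamingDataset(IterableDataset):
        """Hypothetical streaming dataset: yields samples but defines no __len__."""

        def __iter__(self):
            for _ in range(4):
                yield torch.rand(32)


    loader = DataLoader(StreamingDataset(), batch_size=2)

    try:
        # For iterable-style datasets, DataLoader.__len__ delegates to
        # len(self.dataset), so this raises TypeError.
        len(loader)
    except TypeError:
        # has_len(loader) returns False in this case, which is what makes
        # TPUSpawnPlugin raise the new MisconfigurationException.
        print("unsized dataloader -> TPUSpawnPlugin raises MisconfigurationException")

Hooking the check into both `connect` (for dataloaders attached to the model) and `process_dataloader` (for dataloaders coming from the trainer) means the error surfaces up front, rather than as a more opaque failure later during TPU spawning.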