TPUSpawn + IterableDataset error message (#6875)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Ethan Harris 2021-04-08 15:27:48 +01:00 committed by GitHub
parent 87f0aeac25
commit 1c2ecbf70c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 117 additions and 3 deletions

View File

@ -15,15 +15,17 @@ import io
import os
import re
import time
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
import torch
import torch.multiprocessing as mp
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
@ -40,6 +42,11 @@ if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf
if TYPE_CHECKING:
from torch.nn import Module
from torch.utils.data import DataLoader
class TPUSpawnPlugin(DDPSpawnPlugin):
def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
@ -47,7 +54,39 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
self.tpu_local_core_rank = 0
self.start_method = None
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
@staticmethod
def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']):
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
for dataloader in dataloaders:
if not has_len(dataloader):
raise MisconfigurationException(
"TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
" HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
)
@staticmethod
def _validate_patched_dataloaders(model: 'Module') -> None:
"""Validate and fail fast if the dataloaders were passed directly to fit.
"""
if hasattr(model, 'train_dataloader') and isinstance(model.train_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)
if hasattr(model, 'val_dataloader') and isinstance(model.val_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)
if hasattr(model, 'test_dataloader') and isinstance(model.test_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)
if hasattr(model, 'predict_dataloader') and isinstance(model.predict_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)
def connect(self, model: 'Module') -> None:
TPUSpawnPlugin._validate_patched_dataloaders(model)
return super().connect(model)
def setup(self, model: 'Module') -> 'Module':
self.create_mp_queue()
return self.model
@ -64,7 +103,8 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
def is_distributed(self):
return self.world_size != 1
def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
def process_dataloader(self, dataloader: 'DataLoader') -> MpDeviceLoader:
TPUSpawnPlugin._validate_dataloader(dataloader)
device = xm.xla_device()
dataloader = MpDeviceLoader(dataloader, device)
return dataloader

View File

@ -0,0 +1,74 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
import pytest
from torch.utils.data import DataLoader
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader
class BoringModelNoDataloaders(BoringModel):
def train_dataloader(self):
raise NotImplementedError
def val_dataloader(self):
raise NotImplementedError
def test_dataloader(self):
raise NotImplementedError
def predict_dataloader(self):
raise NotImplementedError
_loader = DataLoader(RandomDataset(32, 64))
_loader_no_len = CustomNotImplementedErrorDataloader(_loader)
@pytest.mark.parametrize(
"train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders",
[
(_loader_no_len, None, None, None),
(None, _loader_no_len, None, None),
(None, None, _loader_no_len, None),
(None, None, None, _loader_no_len),
(None, [_loader, _loader_no_len], None, None),
],
)
def test_error_patched_iterable_dataloaders(
tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders
):
model = BoringModelNoDataloaders()
connector = DataConnector(MagicMock())
connector.attach_dataloaders(
model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders,
test_dataloaders=test_dataloaders,
predict_dataloaders=predict_dataloaders,
)
with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
TPUSpawnPlugin(MagicMock()).connect(model)
def test_error_process_iterable_dataloader(tmpdir):
with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len)