TPUSpawn + IterableDataset error message (#6875)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
87f0aeac25
commit
1c2ecbf70c
|
@ -15,15 +15,17 @@ import io
|
|||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
|
||||
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.data import has_len
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
|
@ -40,6 +42,11 @@ if _OMEGACONF_AVAILABLE:
|
|||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.nn import Module
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class TPUSpawnPlugin(DDPSpawnPlugin):
|
||||
|
||||
def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
|
||||
|
@ -47,7 +54,39 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
self.tpu_local_core_rank = 0
|
||||
self.start_method = None
|
||||
|
||||
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
|
||||
@staticmethod
|
||||
def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']):
|
||||
if not isinstance(dataloaders, list):
|
||||
dataloaders = [dataloaders]
|
||||
|
||||
for dataloader in dataloaders:
|
||||
if not has_len(dataloader):
|
||||
raise MisconfigurationException(
|
||||
"TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
|
||||
" HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_patched_dataloaders(model: 'Module') -> None:
|
||||
"""Validate and fail fast if the dataloaders were passed directly to fit.
|
||||
"""
|
||||
if hasattr(model, 'train_dataloader') and isinstance(model.train_dataloader, _PatchDataLoader):
|
||||
TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)
|
||||
|
||||
if hasattr(model, 'val_dataloader') and isinstance(model.val_dataloader, _PatchDataLoader):
|
||||
TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)
|
||||
|
||||
if hasattr(model, 'test_dataloader') and isinstance(model.test_dataloader, _PatchDataLoader):
|
||||
TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)
|
||||
|
||||
if hasattr(model, 'predict_dataloader') and isinstance(model.predict_dataloader, _PatchDataLoader):
|
||||
TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)
|
||||
|
||||
def connect(self, model: 'Module') -> None:
|
||||
TPUSpawnPlugin._validate_patched_dataloaders(model)
|
||||
return super().connect(model)
|
||||
|
||||
def setup(self, model: 'Module') -> 'Module':
|
||||
self.create_mp_queue()
|
||||
return self.model
|
||||
|
||||
|
@ -64,7 +103,8 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
def is_distributed(self):
|
||||
return self.world_size != 1
|
||||
|
||||
def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
|
||||
def process_dataloader(self, dataloader: 'DataLoader') -> MpDeviceLoader:
|
||||
TPUSpawnPlugin._validate_dataloader(dataloader)
|
||||
device = xm.xla_device()
|
||||
dataloader = MpDeviceLoader(dataloader, device)
|
||||
return dataloader
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
|
||||
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset
|
||||
from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader
|
||||
|
||||
|
||||
class BoringModelNoDataloaders(BoringModel):
|
||||
def train_dataloader(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def val_dataloader(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_dataloader(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def predict_dataloader(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
_loader = DataLoader(RandomDataset(32, 64))
|
||||
_loader_no_len = CustomNotImplementedErrorDataloader(_loader)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders",
|
||||
[
|
||||
(_loader_no_len, None, None, None),
|
||||
(None, _loader_no_len, None, None),
|
||||
(None, None, _loader_no_len, None),
|
||||
(None, None, None, _loader_no_len),
|
||||
(None, [_loader, _loader_no_len], None, None),
|
||||
],
|
||||
)
|
||||
def test_error_patched_iterable_dataloaders(
|
||||
tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders
|
||||
):
|
||||
model = BoringModelNoDataloaders()
|
||||
connector = DataConnector(MagicMock())
|
||||
|
||||
connector.attach_dataloaders(
|
||||
model,
|
||||
train_dataloader=train_dataloader,
|
||||
val_dataloaders=val_dataloaders,
|
||||
test_dataloaders=test_dataloaders,
|
||||
predict_dataloaders=predict_dataloaders,
|
||||
)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
|
||||
TPUSpawnPlugin(MagicMock()).connect(model)
|
||||
|
||||
|
||||
def test_error_process_iterable_dataloader(tmpdir):
|
||||
with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
|
||||
TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len)
|
Loading…
Reference in New Issue