diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 94edfe5354..0a18b7a247 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -138,7 +138,10 @@ class TPUSpawnPlugin(DDPSpawnPlugin): def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: TPUSpawnPlugin._validate_dataloader(dataloader) - return MpDeviceLoader(dataloader, self.root_device) + dataloader = MpDeviceLoader(dataloader, self.root_device) + # Mimic interface to torch.utils.data.DataLoader + dataloader.dataset = dataloader._loader.dataset + return dataloader def configure_ddp(self) -> None: pass diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index 98a96d15db..4b2f56c329 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -18,6 +18,7 @@ from unittest.mock import patch import pytest import torch from torch import nn +from torch.utils.data import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.accelerators.cpu import CPUAccelerator @@ -25,7 +26,7 @@ from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO from pytorch_lightning.utilities import find_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel +from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf from tests.helpers.utils import pl_multi_process_test @@ -300,3 +301,12 @@ def test_tpu_invalid_raises(): def test_xla_checkpoint_plugin_being_default(): trainer = Trainer(tpu_cores=8) assert isinstance(trainer.training_type_plugin.checkpoint_io, XLACheckpointIO) + + +@RunIf(tpu=True) +@patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm") +def test_mp_device_dataloader_attribute(_): + dataset = RandomDataset(32, 64) + dataloader = TPUSpawnPlugin().process_dataloader(DataLoader(dataset)) + + assert dataloader.dataset == dataset