Set `dataset` attribute to `MpDeviceLoader` used in TPU Spawn (#10151)

This commit is contained in:
Kaushik B 2021-10-27 01:23:01 +05:30 committed by GitHub
parent 5ade197580
commit c33df2639f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 2 deletions

View File

@@ -138,7 +138,10 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
# Wraps a ``DataLoader`` in an XLA ``MpDeviceLoader`` so batches are loaded onto
# the TPU device. NOTE(review): this span is a unified-diff overlay with the
# +/- markers stripped — the first ``return`` line is the pre-change code this
# commit removed; the three lines after it are the replacement it added.
def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
TPUSpawnPlugin._validate_dataloader(dataloader)
# removed by this commit: returned the wrapper directly, which does not
# expose the wrapped loader's ``dataset`` attribute
return MpDeviceLoader(dataloader, self.root_device)
# added by this commit: forward ``.dataset`` from the inner loader so the
# wrapper can be used where a ``torch.utils.data.DataLoader`` is expected
dataloader = MpDeviceLoader(dataloader, self.root_device)
# Mimic interface to torch.utils.data.DataLoader
dataloader.dataset = dataloader._loader.dataset
return dataloader
# Intentional no-op: DDP wrapping is not configured for the XLA spawn plugin.
# NOTE(review): rationale not visible in this diff view — presumably XLA's own
# spawn/replication handles distribution; confirm against the full class.
def configure_ddp(self) -> None:
pass

View File

@@ -18,6 +18,7 @@ from unittest.mock import patch
import pytest
import torch
from torch import nn
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -25,7 +26,7 @@ from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO
from pytorch_lightning.utilities import find_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test
@@ -300,3 +301,12 @@ def test_tpu_invalid_raises():
def test_xla_checkpoint_plugin_being_default():
trainer = Trainer(tpu_cores=8)
assert isinstance(trainer.training_type_plugin.checkpoint_io, XLACheckpointIO)
# Regression test for #10151: ``process_dataloader`` must forward the wrapped
# loader's ``dataset`` attribute so the returned ``MpDeviceLoader`` can stand in
# for a ``torch.utils.data.DataLoader``. ``xm`` is patched out so the plugin can
# be constructed without a live XLA runtime; ``RunIf(tpu=True)`` still gates the
# test to TPU-capable CI.
@RunIf(tpu=True)
@patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm")
def test_mp_device_dataloader_attribute(_):
dataset = RandomDataset(32, 64)
dataloader = TPUSpawnPlugin().process_dataloader(DataLoader(dataset))
# same object, not just equal: the wrapper forwards ``._loader.dataset``
assert dataloader.dataset == dataset