Set `dataset` attribute to `MpDeviceLoader` used in TPU Spawn (#10151)
This commit is contained in:
parent
5ade197580
commit
c33df2639f
|
@ -138,7 +138,10 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
|
||||
def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
|
||||
TPUSpawnPlugin._validate_dataloader(dataloader)
|
||||
return MpDeviceLoader(dataloader, self.root_device)
|
||||
dataloader = MpDeviceLoader(dataloader, self.root_device)
|
||||
# Mimic interface to torch.utils.data.DataLoader
|
||||
dataloader.dataset = dataloader._loader.dataset
|
||||
return dataloader
|
||||
|
||||
def configure_ddp(self) -> None:
|
||||
pass
|
||||
|
|
|
@ -18,6 +18,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.accelerators.cpu import CPUAccelerator
|
||||
|
@ -25,7 +26,7 @@ from pytorch_lightning.accelerators.tpu import TPUAccelerator
|
|||
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO
|
||||
from pytorch_lightning.utilities import find_shared_parameters
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset
|
||||
from tests.helpers.runif import RunIf
|
||||
from tests.helpers.utils import pl_multi_process_test
|
||||
|
||||
|
@ -300,3 +301,12 @@ def test_tpu_invalid_raises():
|
|||
def test_xla_checkpoint_plugin_being_default():
|
||||
trainer = Trainer(tpu_cores=8)
|
||||
assert isinstance(trainer.training_type_plugin.checkpoint_io, XLACheckpointIO)
|
||||
|
||||
|
||||
@RunIf(tpu=True)
|
||||
@patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm")
|
||||
def test_mp_device_dataloader_attribute(_):
|
||||
dataset = RandomDataset(32, 64)
|
||||
dataloader = TPUSpawnPlugin().process_dataloader(DataLoader(dataset))
|
||||
|
||||
assert dataloader.dataset == dataset
|
||||
|
|
Loading…
Reference in New Issue