diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index ad3b75b59f..d36e874cba 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -238,18 +238,15 @@ class LightningLite(ABC):
                 )
             sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
 
-        dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
-        try:
-            dataloader = type(dataloader)(**dataloader_kwargs)
-        except TypeError:
-            dataloader_kwargs.pop("dataset")
-            dataloader = type(dataloader)(**dataloader_kwargs)
+        # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
+        dataloader = TrainerDataLoadingMixin._update_dataloader(dataloader, sampler)
+
         # add worker_init_fn for correct seeding in worker processes
         TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
-        return _LiteDataLoader(
-            dataloader=self._strategy.process_dataloader(dataloader),
-            device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
-        )
+
+        dataloader = self._strategy.process_dataloader(dataloader)
+        device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
+        return _LiteDataLoader(dataloader=dataloader, device=device)
 
     def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
         """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py
index b563e56e2f..bd69cf3594 100644
--- a/tests/lite/test_lite.py
+++ b/tests/lite/test_lite.py
@@ -24,7 +24,12 @@ from torch import nn
 from torch.utils.data import DataLoader, DistributedSampler, Sampler
 
 from pytorch_lightning.lite import LightningLite
-from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
+from pytorch_lightning.lite.wrappers import (
+    _LiteDataLoader,
+    _LiteModule,
+    _LiteOptimizer,
+    _replace_dataloader_init_method,
+)
 from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
 from pytorch_lightning.utilities import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -192,57 +197,6 @@ def test_setup_dataloaders_with_custom_type():
     LiteWithCustomDataLoader().run()
 
 
-def test_setup_custom_dataloaders():
-    """Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
-    lite = EmptyLite()
-
-    class CustomDataLoader(DataLoader):
-        def __init__(self, value: int = 2, *args, **kwargs):
-            self.value = value
-            super().__init__(range(value), *args, **kwargs)
-
-    dataloader = CustomDataLoader(2, batch_size=2)
-
-    # single dataloader
-    lite_dataloader = lite.setup_dataloaders(dataloader)
-    assert lite_dataloader._dataloader
-    assert lite_dataloader.value == 2
-    batch0 = next(iter(lite_dataloader))
-    assert torch.equal(batch0, torch.tensor([0, 1]))
-
-    class CustomDataLoader2(DataLoader):
-        def __init__(self, range, *args, **kwargs):
-            self.range = range
-            super().__init__(range, *args, **kwargs)
-
-    dataloader = CustomDataLoader2(range(2), batch_size=2)
-
-    # single dataloader
-    lite_dataloader = lite.setup_dataloaders(dataloader)
-    assert lite_dataloader._dataloader
-    batch0 = next(iter(lite_dataloader))
-    assert torch.equal(batch0, torch.tensor([0, 1]))
-
-    class CustomDataLoader(DataLoader):
-        def __init__(self, value: int, *args, **kwargs):
-            super().__init__(range(value), *args, **kwargs)
-
-    class LiteWithCustomDataLoader(LightningLite):
-        def run(self):
-            # This doesn't fail as the context manager would save all the arguments provided
-            # to the dataloaders.
-            dataloader = CustomDataLoader(2, batch_size=2)
-            self.setup_dataloaders(dataloader)
-
-    LiteWithCustomDataLoader().run()
-
-    with pytest.raises(
-        MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
-    ):
-        dataloader = CustomDataLoader(2, batch_size=2)
-        lite_dataloader = lite.setup_dataloaders(dataloader)
-
-
 def test_setup_dataloaders_twice_fails():
     """Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
     lite = EmptyLite()
@@ -490,3 +444,25 @@ def test_deepspeed_multiple_models():
             assert self.is_global_zero == (self.local_rank == 0)
 
     Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
+
+
+def test_replace_dataloader_init_method():
+    """Test that the context manager enables saving the parameters passed to the DataLoader __init__ method."""
+
+    class CustomDataLoader(DataLoader):
+        def __init__(self, extra_argument: int, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+    dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
+    lite = EmptyLite()
+    with pytest.raises(MisconfigurationException, match="extra_argument"):
+        dataloader = lite.setup_dataloaders(dataloader)
+
+    with _replace_dataloader_init_method():
+        dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
+        assert dataloader.extra_argument == 1
+        dataloader = lite.setup_dataloaders(dataloader)
+
+        dataloader = CustomDataLoader(1, range(1))
+        assert dataloader.extra_argument == 1
+        dataloader = lite.setup_dataloaders(dataloader)
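
Note: `_replace_dataloader_init_method` is imported from `pytorch_lightning.lite.wrappers` in the new test, but its implementation is not part of this diff. The sketch below (hypothetical helper names, not the library's actual code) illustrates the idea the test relies on: while the context manager is active, every argument passed to a `DataLoader` subclass constructor is recorded as an attribute on the instance, so `setup_dataloaders` can later re-instantiate the loader with a new sampler.

# Minimal sketch, assuming a patch-and-restore approach; `_record_dataloader_init_args` and
# `_wrap_init` are hypothetical names, not the actual pytorch_lightning.lite.wrappers code.
import inspect
from contextlib import contextmanager
from itertools import chain
from typing import Any, Callable, Dict, Generator, Type

from torch.utils.data import DataLoader


def _wrap_init(init: Callable) -> Callable:
    def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
        # map positional arguments onto the subclass' named parameters and store every
        # constructor argument as an attribute, so the loader can be re-created later
        param_names = [name for name in inspect.signature(init).parameters if name not in ("self", "args", "kwargs")]
        for name, value in chain(zip(param_names, args), kwargs.items()):
            setattr(obj, name, value)
        init(obj, *args, **kwargs)

    return wrapper


@contextmanager
def _record_dataloader_init_args() -> Generator[None, None, None]:
    # only direct DataLoader subclasses are patched in this simplified sketch
    originals: Dict[Type[DataLoader], Callable] = {cls: cls.__init__ for cls in DataLoader.__subclasses__()}
    for cls, init in originals.items():
        cls.__init__ = _wrap_init(init)
    try:
        yield
    finally:
        for cls, init in originals.items():
            cls.__init__ = init

Under that assumption, a `CustomDataLoader(1, range(1))` created inside the context manager ends up with `dataloader.extra_argument == 1`, which is what the new test asserts before handing the loader to `setup_dataloaders`.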