Revert part of #10279 (#10376)

Carlos Mocholí 2021-11-08 12:28:58 +01:00 committed by GitHub
parent 89e1360e75
commit 613aa09514
2 changed files with 35 additions and 62 deletions


@@ -238,18 +238,15 @@ class LightningLite(ABC):
)
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
try:
dataloader = type(dataloader)(**dataloader_kwargs)
except TypeError:
dataloader_kwargs.pop("dataset")
dataloader = type(dataloader)(**dataloader_kwargs)
# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
dataloader = TrainerDataLoadingMixin._update_dataloader(dataloader, sampler)
# add worker_init_fn for correct seeding in worker processes
TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
return _LiteDataLoader(
dataloader=self._strategy.process_dataloader(dataloader),
device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
)
dataloader = self._strategy.process_dataloader(dataloader)
device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
return _LiteDataLoader(dataloader=dataloader, device=device)
def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.


@@ -24,7 +24,12 @@ from torch import nn
from torch.utils.data import DataLoader, DistributedSampler, Sampler
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.lite.wrappers import (
_LiteDataLoader,
_LiteModule,
_LiteOptimizer,
_replace_dataloader_init_method,
)
from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -192,57 +197,6 @@ def test_setup_dataloaders_with_custom_type():
LiteWithCustomDataLoader().run()
def test_setup_custom_dataloaders():
"""Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
lite = EmptyLite()
class CustomDataLoader(DataLoader):
def __init__(self, value: int = 2, *args, **kwargs):
self.value = value
super().__init__(range(value), *args, **kwargs)
dataloader = CustomDataLoader(2, batch_size=2)
# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
assert lite_dataloader.value == 2
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))
class CustomDataLoader2(DataLoader):
def __init__(self, range, *args, **kwargs):
self.range = range
super().__init__(range, *args, **kwargs)
dataloader = CustomDataLoader2(range(2), batch_size=2)
# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))
class CustomDataLoader(DataLoader):
def __init__(self, value: int, *args, **kwargs):
super().__init__(range(value), *args, **kwargs)
class LiteWithCustomDataLoader(LightningLite):
def run(self):
# This doesn't fail because the context manager saves all the arguments provided
# to the dataloaders.
dataloader = CustomDataLoader(2, batch_size=2)
self.setup_dataloaders(dataloader)
LiteWithCustomDataLoader().run()
with pytest.raises(
MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
):
dataloader = CustomDataLoader(2, batch_size=2)
lite_dataloader = lite.setup_dataloaders(dataloader)
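
The ``pytest.raises`` block above exists because ``value`` is consumed by ``CustomDataLoader.__init__`` without being stored under that name, so the re-instantiation performed when injecting the sampler cannot recover it. A standalone sketch of that failure mode (a plain ``TypeError`` here; in ``setup_dataloaders`` it surfaces as the ``MisconfigurationException`` matched above):

from torch.utils.data import DataLoader


class CustomDataLoader(DataLoader):
    def __init__(self, value: int, *args, **kwargs):
        # `value` is used to build the dataset but never stored, so it cannot
        # be recovered later by inspecting the finished instance.
        super().__init__(range(value), *args, **kwargs)


loader = CustomDataLoader(2, batch_size=2)
# Rebuilding the loader from the attributes visible on the instance loses `value`:
CustomDataLoader(dataset=loader.dataset, batch_size=loader.batch_size)
# -> TypeError: __init__() missing 1 required positional argument: 'value'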
def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
lite = EmptyLite()
@@ -490,3 +444,25 @@ def test_deepspeed_multiple_models():
assert self.is_global_zero == (self.local_rank == 0)
Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
def test_replace_dataloader_init_method():
"""Test that the context manager enables to save the parameters passed to the DataLoader __init__ method."""
class CustomDataLoader(DataLoader):
def __init__(self, extra_argument: int, *args, **kwargs):
super().__init__(*args, **kwargs)
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
lite = EmptyLite()
with pytest.raises(MisconfigurationException, match="extra_argument"):
dataloader = lite.setup_dataloaders(dataloader)
with _replace_dataloader_init_method():
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
assert dataloader.extra_argument == 1
dataloader = lite.setup_dataloaders(dataloader)
with _replace_dataloader_init_method():
    dataloader = CustomDataLoader(1, range(1))
assert dataloader.extra_argument == 1
dataloader = lite.setup_dataloaders(dataloader)
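
For reference, the behaviour this new test checks can be pictured with a rough sketch of such a context manager: temporarily wrap the ``__init__`` of the existing ``DataLoader`` subclasses so every constructor argument is also stored as an attribute on the instance, which is what later allows ``setup_dataloaders`` to re-instantiate the loader with an injected sampler. This is only an illustration of the idea (direct subclasses only), not the actual ``_replace_dataloader_init_method`` from ``pytorch_lightning.lite.wrappers``.

import inspect
from contextlib import contextmanager
from itertools import chain

from torch.utils.data import DataLoader


def _wrap_init(init):
    # Record every named argument passed to ``init`` on the instance before
    # running the original ``__init__``.
    def wrapper(obj, *args, **kwargs):
        param_names = [
            name
            for name, p in inspect.signature(init).parameters.items()
            if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        ]
        for name, value in chain(zip(param_names, args), kwargs.items()):
            setattr(obj, name, value)
        init(obj, *args, **kwargs)

    return wrapper


@contextmanager
def replace_dataloader_init_method():
    # Patch the ``__init__`` of the existing DataLoader subclasses for the
    # duration of the ``with`` block, then restore the originals.
    subclasses = DataLoader.__subclasses__()
    originals = {cls: cls.__init__ for cls in subclasses}
    for cls in subclasses:
        cls.__init__ = _wrap_init(cls.__init__)
    try:
        yield
    finally:
        for cls, init in originals.items():
            cls.__init__ = init

With something like this in place, a ``CustomDataLoader(extra_argument=1, dataset=range(1))`` created inside the ``with`` block carries ``extra_argument`` as an attribute, which is what the assertions in the test above rely on.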