parent
89e1360e75
commit
613aa09514
|
@ -238,18 +238,15 @@ class LightningLite(ABC):
|
|||
)
|
||||
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
|
||||
|
||||
dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
|
||||
try:
|
||||
dataloader = type(dataloader)(**dataloader_kwargs)
|
||||
except TypeError:
|
||||
dataloader_kwargs.pop("dataset")
|
||||
dataloader = type(dataloader)(**dataloader_kwargs)
|
||||
# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
|
||||
dataloader = TrainerDataLoadingMixin._update_dataloader(dataloader, sampler)
|
||||
|
||||
# add worker_init_fn for correct seeding in worker processes
|
||||
TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
|
||||
return _LiteDataLoader(
|
||||
dataloader=self._strategy.process_dataloader(dataloader),
|
||||
device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
|
||||
)
|
||||
|
||||
dataloader = self._strategy.process_dataloader(dataloader)
|
||||
device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
|
||||
return _LiteDataLoader(dataloader=dataloader, device=device)
|
||||
|
||||
def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
|
||||
"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
|
||||
|
|
|
@ -24,7 +24,12 @@ from torch import nn
|
|||
from torch.utils.data import DataLoader, DistributedSampler, Sampler
|
||||
|
||||
from pytorch_lightning.lite import LightningLite
|
||||
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
|
||||
from pytorch_lightning.lite.wrappers import (
|
||||
_LiteDataLoader,
|
||||
_LiteModule,
|
||||
_LiteOptimizer,
|
||||
_replace_dataloader_init_method,
|
||||
)
|
||||
from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
|
||||
from pytorch_lightning.utilities import DistributedType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -192,57 +197,6 @@ def test_setup_dataloaders_with_custom_type():
|
|||
LiteWithCustomDataLoader().run()
|
||||
|
||||
|
||||
def test_setup_custom_dataloaders():
|
||||
"""Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
|
||||
lite = EmptyLite()
|
||||
|
||||
class CustomDataLoader(DataLoader):
|
||||
def __init__(self, value: int = 2, *args, **kwargs):
|
||||
self.value = value
|
||||
super().__init__(range(value), *args, **kwargs)
|
||||
|
||||
dataloader = CustomDataLoader(2, batch_size=2)
|
||||
|
||||
# single dataloader
|
||||
lite_dataloader = lite.setup_dataloaders(dataloader)
|
||||
assert lite_dataloader._dataloader
|
||||
assert lite_dataloader.value == 2
|
||||
batch0 = next(iter(lite_dataloader))
|
||||
assert torch.equal(batch0, torch.tensor([0, 1]))
|
||||
|
||||
class CustomDataLoader2(DataLoader):
|
||||
def __init__(self, range, *args, **kwargs):
|
||||
self.range = range
|
||||
super().__init__(range, *args, **kwargs)
|
||||
|
||||
dataloader = CustomDataLoader2(range(2), batch_size=2)
|
||||
|
||||
# single dataloader
|
||||
lite_dataloader = lite.setup_dataloaders(dataloader)
|
||||
assert lite_dataloader._dataloader
|
||||
batch0 = next(iter(lite_dataloader))
|
||||
assert torch.equal(batch0, torch.tensor([0, 1]))
|
||||
|
||||
class CustomDataLoader(DataLoader):
|
||||
def __init__(self, value: int, *args, **kwargs):
|
||||
super().__init__(range(value), *args, **kwargs)
|
||||
|
||||
class LiteWithCustomDataLoader(LightningLite):
|
||||
def run(self):
|
||||
# This doesn't fail as the context manager would save all the arguments provided
|
||||
# to the dataloaders.
|
||||
dataloader = CustomDataLoader(2, batch_size=2)
|
||||
self.setup_dataloaders(dataloader)
|
||||
|
||||
LiteWithCustomDataLoader().run()
|
||||
|
||||
with pytest.raises(
|
||||
MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
|
||||
):
|
||||
dataloader = CustomDataLoader(2, batch_size=2)
|
||||
lite_dataloader = lite.setup_dataloaders(dataloader)
|
||||
|
||||
|
||||
def test_setup_dataloaders_twice_fails():
|
||||
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
|
||||
lite = EmptyLite()
|
||||
|
@ -490,3 +444,25 @@ def test_deepspeed_multiple_models():
|
|||
assert self.is_global_zero == (self.local_rank == 0)
|
||||
|
||||
Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
|
||||
|
||||
|
||||
def test_replace_dataloader_init_method():
|
||||
"""Test that the context manager enables to save the parameters passed to the DataLoader __init__ method."""
|
||||
|
||||
class CustomDataLoader(DataLoader):
|
||||
def __init__(self, extra_argument: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
|
||||
lite = EmptyLite()
|
||||
with pytest.raises(MisconfigurationException, match="extra_argument"):
|
||||
dataloader = lite.setup_dataloaders(dataloader)
|
||||
|
||||
with _replace_dataloader_init_method():
|
||||
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
|
||||
assert dataloader.extra_argument == 1
|
||||
dataloader = lite.setup_dataloaders(dataloader)
|
||||
|
||||
dataloader = CustomDataLoader(1, range(1))
|
||||
assert dataloader.extra_argument == 1
|
||||
dataloader = lite.setup_dataloaders(dataloader)
|
||||
|
|
Loading…
Reference in New Issue