From facaff94b8aa33e9b5138dc70e897df21594482b Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Mon, 1 Nov 2021 18:33:13 +0000
Subject: [PATCH] Add custom dataloader support with Lite (#10279)

---
 CHANGELOG.md                       |  1 +
 pytorch_lightning/lite/lite.py     | 32 ++++++++------
 pytorch_lightning/lite/wrappers.py | 71 +++++++++++++++++++++++++-----
 tests/lite/test_lite.py            | 51 +++++++++++++++++++++
 tests/lite/test_wrappers.py        | 14 +++++-
 5 files changed, 144 insertions(+), 25 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 60702a2866..161c5b9309 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -119,6 +119,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164))
     * Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162))
     * Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175))
+    * Made the `_LiteDataLoader` an iterator and added support for custom dataloaders ([#10279](https://github.com/PyTorchLightning/pytorch-lightning/pull/10279))
 - Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
 - Added `ckpt_path` argument for `Trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
 - Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index 7d0ff6a436..2e6f10d356 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -25,7 +25,12 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler

 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
+from pytorch_lightning.lite.wrappers import (
+    _LiteDataLoader,
+    _LiteModule,
+    _LiteOptimizer,
+    _replace_dataloader_init_method,
+)
 from pytorch_lightning.plugins import (
     DDPShardedPlugin,
     DDPSpawnPlugin,
@@ -183,7 +188,7 @@ class LightningLite(ABC):

     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[DataLoader, List[DataLoader], Iterable]:
+    ) -> Union[Iterable, List[Iterable]]:
         """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
         dataloader, call this method individually for each one.

@@ -208,7 +213,7 @@ class LightningLite(ABC):

     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[Iterable, DataLoader]:
+    ) -> Iterable:
         """Setup a single dataloader for accelerated training.

         Args:
@@ -233,17 +238,18 @@ class LightningLite(ABC):
                 )
             sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

-        kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
-        device = self.device if move_to_device else None
-        if isinstance(self._strategy, TPUSpawnPlugin):
-            dataloader = DataLoader(**kwargs)
-        else:
-            dataloader = _LiteDataLoader(device=device, **kwargs)
-
+        dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
+        try:
+            dataloader = type(dataloader)(**dataloader_kwargs)
+        except TypeError:
+            dataloader_kwargs.pop("dataset")
+            dataloader = type(dataloader)(**dataloader_kwargs)
         # add worker_init_fn for correct seeding in worker processes
         TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
-
-        return self._strategy.process_dataloader(dataloader)
+        return _LiteDataLoader(
+            dataloader=self._strategy.process_dataloader(dataloader),
+            device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
+        )

     def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
         """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
@@ -400,7 +406,7 @@ class LightningLite(ABC):
             return run_method(*args, **kwargs)

     def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
-        with self._strategy.model_sharded_context():
+        with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
             return run_method(*args, **kwargs)

     def _set_plugin_specific_precision_variables(self) -> None:
diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 3dd387319a..d9acba70bc 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Generator, Iterator, Optional, Union
+import functools
+import inspect
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type, Union

 import torch
 from torch import nn as nn
@@ -100,17 +103,65 @@ class _LiteModule(nn.Module):
         return output


-class _LiteDataLoader(DataLoader):
-    def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None:
-        """The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds
-        additional features such as moving the data to the device automatically.
+def _wrap_init(f: Callable) -> Callable:
+    @functools.wraps(f)
+    def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
+        params = dict(inspect.signature(module._old_init).parameters)
+        params.pop("args")
+        params.pop("kwargs")
+        for init_name, init_arg in zip(params, args):
+            setattr(module, init_name, init_arg)
+        f(module, *args, **kwargs)
+
+    return wrapper
+
+
+# https://stackoverflow.com/a/63851681/9201239
+def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
+    subclass_list = []
+
+    def recurse(cl: Type[Any]) -> None:
+        for subclass in cl.__subclasses__():
+            subclass_list.append(subclass)
+            recurse(subclass)
+
+    recurse(cls)
+    return set(subclass_list)
+
+
+def _enable_class(cls: Type[Any]) -> None:
+    cls._old_init = cls.__init__
+    cls.__init__ = _wrap_init(cls.__init__)
+
+
+def _disable_class(cls: Type[Any]) -> None:
+    cls.__init__ = cls._old_init
+    del cls._old_init
+
+
+@contextmanager
+def _replace_dataloader_init_method() -> Generator:
+    """This context manager is used to support custom :class:`~torch.utils.data.DataLoader` subclasses."""
+    for subclass in _get_all_subclasses(DataLoader):
+        _enable_class(subclass)
+    yield
+    for subclass in _get_all_subclasses(DataLoader):
+        _disable_class(subclass)
+
+
+class _LiteDataLoader:
+    def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None:
+        """The LiteDataLoader is a wrapper around an iterable. It moves the data to the device automatically if
+        the device is specified.

         Args:
+            dataloader: The dataloader to wrap.
             device: The device to which the data should be moved. By default the device is `None` and no data
                 transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
-            **dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts.
""" - super().__init__(**dl_kwargs) + super().__init__() + self.__dict__.update(getattr(dataloader, "__dict__", {})) + self._dataloader = dataloader self._device = device @property @@ -118,9 +169,9 @@ class _LiteDataLoader(DataLoader): return self._device def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: - iterator = super().__iter__() + dataloader_iter = iter(self._dataloader) if self._device is None: - return iterator + return dataloader_iter - for item in iterator: + for item in dataloader_iter: yield move_data_to_device(item, self._device) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 916e0aa542..8eac30f9cf 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -164,6 +164,57 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 +def test_setup_custom_dataloaders(): + """Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader.""" + lite = EmptyLite() + + class CustomDataLoader(DataLoader): + def __init__(self, value: int = 2, *args, **kwargs): + self.value = value + super().__init__(range(value), *args, **kwargs) + + dataloader = CustomDataLoader(2, batch_size=2) + + # single dataloader + lite_dataloader = lite.setup_dataloaders(dataloader) + assert lite_dataloader._dataloader + assert lite_dataloader.value == 2 + batch0 = next(iter(lite_dataloader)) + assert torch.equal(batch0, torch.tensor([0, 1])) + + class CustomDataLoader2(DataLoader): + def __init__(self, range, *args, **kwargs): + self.range = range + super().__init__(range, *args, **kwargs) + + dataloader = CustomDataLoader2(range(2), batch_size=2) + + # single dataloader + lite_dataloader = lite.setup_dataloaders(dataloader) + assert lite_dataloader._dataloader + batch0 = next(iter(lite_dataloader)) + assert torch.equal(batch0, torch.tensor([0, 1])) + + class CustomDataLoader(DataLoader): + def __init__(self, value: int, *args, **kwargs): + super().__init__(range(value), *args, **kwargs) + + class LiteWithCustomDataLoader(LightningLite): + def run(self): + # This doesn't fail as the context manager would save all the arguments provided + # to the dataloaders. 
+            dataloader = CustomDataLoader(2, batch_size=2)
+            self.setup_dataloaders(dataloader)
+
+    LiteWithCustomDataLoader().run()
+
+    with pytest.raises(
+        MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
+    ):
+        dataloader = CustomDataLoader(2, batch_size=2)
+        lite_dataloader = lite.setup_dataloaders(dataloader)
+
+
 def test_setup_dataloaders_twice_fails():
     """Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
     lite = EmptyLite()
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py
index 3e2e9ac7a9..4dd7b4a890 100644
--- a/tests/lite/test_wrappers.py
+++ b/tests/lite/test_wrappers.py
@@ -15,6 +15,7 @@ from unittest.mock import ANY, Mock

 import pytest
 import torch
+from torch.utils.data.dataloader import DataLoader

 from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -73,8 +74,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
     sample1 = torch.tensor(1, device=src_device)
     sample2 = {"data": torch.tensor(2, device=src_device)}
     sample3 = {"data": torch.tensor(3, device=src_device)}
-    data = [sample0, sample1, sample2, sample3]
-    lite_dataloader = _LiteDataLoader(device=dest_device, dataset=data, batch_size=2)
+    dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
+    lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=dest_device)
     iterator = iter(lite_dataloader)

     batch0 = next(iterator)
@@ -83,6 +84,15 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
     batch1 = next(iterator)
     assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))

+    with pytest.raises(StopIteration):
+        batch1 = next(iterator)
+
+    lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device)
+    iterator = iter(lite_dataloader)
+
+    batch0 = next(iterator)
+    assert batch0 == 0
+

 def test_lite_optimizer_wraps():
     """Test that the LiteOptimizer fully wraps the optimizer."""
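A few implementation details in the hunks above are easier to see in isolation. First, `_replace_dataloader_init_method` temporarily wraps `__init__` on every `DataLoader` subclass so that positional constructor arguments are stored as attributes on the instance; `TrainerDataLoadingMixin._get_dataloader_init_kwargs` can then read those attributes back when the loader has to be re-created with a distributed sampler. The snippet below is a minimal, self-contained sketch of that idea only: it patches a single class instead of walking `_get_all_subclasses(DataLoader)`, and the names `_capture_init_args`, `_capture_init_args_for` and `RangeLoader` are illustrative rather than part of the patch.

import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Generator, Type

from torch.utils.data import DataLoader


def _capture_init_args(init: Callable) -> Callable:
    """Wrap ``__init__`` so every positional argument is also stored as an attribute."""

    @functools.wraps(init)
    def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
        # pair positional arguments with their parameter names and remember them on the instance
        param_names = [
            name for name in inspect.signature(init).parameters if name not in ("self", "args", "kwargs")
        ]
        for name, value in zip(param_names, args):
            setattr(obj, name, value)
        init(obj, *args, **kwargs)

    return wrapper


@contextmanager
def _capture_init_args_for(cls: Type) -> Generator[None, None, None]:
    """Temporarily patch ``cls.__init__``; restore the original afterwards."""
    original_init = cls.__init__
    cls.__init__ = _capture_init_args(original_init)
    try:
        yield
    finally:
        cls.__init__ = original_init


class RangeLoader(DataLoader):
    def __init__(self, length: int, *args: Any, **kwargs: Any) -> None:
        # ``length`` is normally lost once ``__init__`` returns, so the loader
        # could not be re-instantiated without the patching above
        super().__init__(range(length), *args, **kwargs)


with _capture_init_args_for(RangeLoader):
    loader = RangeLoader(4, batch_size=2)

# the positional argument is now recoverable, so an equivalent loader can be rebuilt
assert loader.length == 4
rebuilt = RangeLoader(loader.length, batch_size=loader.batch_size)
assert next(iter(rebuilt)).tolist() == [0, 1]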
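Second, the `try`/`except TypeError` around `type(dataloader)(**dataloader_kwargs)` in `_setup_dataloader` exists because the rebuilt keyword arguments always contain a `dataset` entry, while subclasses such as the test's `CustomDataLoader` construct their own dataset and therefore cannot receive that keyword again. Below is a reduced sketch of the fallback, with a hand-written `rebuilt_kwargs` dict standing in for whatever `TrainerDataLoadingMixin._get_dataloader_init_kwargs` would actually return.

from typing import Any, Dict

from torch.utils.data import DataLoader


class RangeLoader(DataLoader):
    def __init__(self, length: int, *args: Any, **kwargs: Any) -> None:
        # builds its own dataset, so passing ``dataset`` through ``kwargs`` would
        # reach ``DataLoader.__init__`` twice and raise a ``TypeError``
        super().__init__(range(length), *args, **kwargs)


def reinstantiate(dataloader: DataLoader, rebuilt_kwargs: Dict[str, Any]) -> DataLoader:
    """Re-create ``dataloader`` from rebuilt keyword arguments, dropping ``dataset`` if unsupported."""
    try:
        return type(dataloader)(**rebuilt_kwargs)
    except TypeError:
        rebuilt_kwargs.pop("dataset")
        return type(dataloader)(**rebuilt_kwargs)


loader = RangeLoader(4, batch_size=2)
# a plain DataLoader would accept ``dataset``; RangeLoader only accepts its own ``length``
rebuilt = reinstantiate(loader, {"dataset": range(4), "length": 4, "batch_size": 2})
assert isinstance(rebuilt, RangeLoader) and rebuilt.batch_size == 2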
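Third, the reworked `_LiteDataLoader` is no longer a `DataLoader` subclass: it copies the wrapped loader's `__dict__` so attributes such as `batch_size` or a captured `value` remain reachable, and it only intercepts iteration to move each batch to the target device. A minimal stand-in for that wrapper pattern is shown below; the real class relies on `move_data_to_device`, which also handles dictionaries and other collections, whereas this sketch moves plain tensors only.

from typing import Any, Iterable, Iterator, Optional

import torch
from torch.utils.data import DataLoader


class DeviceLoader:
    """Illustrative stand-in for ``_LiteDataLoader``: wraps any iterable and moves batches on iteration."""

    def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None:
        # expose the wrapped loader's attributes (batch_size, captured init args, ...) on the wrapper
        self.__dict__.update(getattr(dataloader, "__dict__", {}))
        self._dataloader = dataloader
        self._device = device

    def __iter__(self) -> Iterator[Any]:
        if self._device is None:
            # no device given: behave exactly like the wrapped iterable
            return iter(self._dataloader)
        return (self._move(batch) for batch in self._dataloader)

    def _move(self, batch: Any) -> Any:
        # simplified device transfer: only plain tensors are moved in this sketch
        return batch.to(self._device) if isinstance(batch, torch.Tensor) else batch


loader = DeviceLoader(DataLoader(range(4), batch_size=2), device=torch.device("cpu"))
assert loader.batch_size == 2  # attribute forwarded from the wrapped DataLoader
assert next(iter(loader)).tolist() == [0, 1]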
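Finally, from the user's perspective the visible change is that `self.setup_dataloaders` now accepts custom `DataLoader` subclasses created inside `run()`, because `run()` executes within `_replace_dataloader_init_method()`. A usage sketch under that assumption follows; the `Runner` and `IndexLoader` names and the `accelerator="cpu"` argument are illustrative and follow the `LightningLite` constructor of this release rather than anything introduced by the patch.

from torch.utils.data import DataLoader

from pytorch_lightning.lite import LightningLite


class IndexLoader(DataLoader):
    def __init__(self, length: int, *args, **kwargs):
        # the positional ``length`` is captured while ``run()`` is executing
        super().__init__(range(length), *args, **kwargs)


class Runner(LightningLite):
    def run(self) -> None:
        loader = self.setup_dataloaders(IndexLoader(4, batch_size=2))
        for batch in loader:
            ...  # batches are already on ``self.device``


Runner(accelerator="cpu").run()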