Add custom dataloader support with Lite (#10279)

thomas chaton 2021-11-01 18:33:13 +00:00 committed by GitHub
parent 828b5315aa
commit facaff94b8
5 changed files with 144 additions and 25 deletions
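
For orientation, here is a minimal usage sketch of the feature this commit adds, adapted from the tests further down in the diff (the `Lite` subclass and the asserts are illustrative, not part of the commit): a custom `DataLoader` subclass can now be passed to `setup_dataloaders`, which re-instantiates it (injecting a distributed sampler when needed) and wraps it in `_LiteDataLoader`.

# Illustrative sketch only, adapted from test_setup_custom_dataloaders below.
from torch.utils.data import DataLoader
from pytorch_lightning.lite import LightningLite


class CustomDataLoader(DataLoader):
    def __init__(self, value: int = 2, *args, **kwargs):
        self.value = value
        super().__init__(range(value), *args, **kwargs)


class Lite(LightningLite):
    def run(self):
        # setup_dataloaders re-creates the custom subclass and returns a
        # _LiteDataLoader wrapper around it.
        dataloader = self.setup_dataloaders(CustomDataLoader(2, batch_size=2))
        assert dataloader.value == 2  # custom attributes survive the wrapping
        for batch in dataloader:  # batches are moved to self.device automatically
            ...


Lite().run()

Because `_LiteDataLoader` copies the original dataloader's `__dict__`, custom attributes such as `value` stay accessible after `setup_dataloaders`.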

CHANGELOG.md

@@ -119,6 +119,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164))
* Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162))
* Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175))
* Made `_LiteDataLoader` an iterator and added support for custom dataloaders ([#10279](https://github.com/PyTorchLightning/pytorch-lightning/pull/10279))
- Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
- Added `ckpt_path` argument for `Trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))

pytorch_lightning/lite/lite.py

@@ -25,7 +25,12 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.lite.wrappers import (
_LiteDataLoader,
_LiteModule,
_LiteOptimizer,
_replace_dataloader_init_method,
)
from pytorch_lightning.plugins import (
DDPShardedPlugin,
DDPSpawnPlugin,
@@ -183,7 +188,7 @@ class LightningLite(ABC):
def setup_dataloaders(
self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> Union[DataLoader, List[DataLoader], Iterable]:
) -> Union[Iterable, List[Iterable]]:
"""Setup one or multiple dataloaders for accelerated training. If you need different settings for each
dataloader, call this method individually for each one.
@@ -208,7 +213,7 @@
def _setup_dataloader(
self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> Union[Iterable, DataLoader]:
) -> Iterable:
"""Setup a single dataloader for accelerated training.
Args:
@@ -233,17 +238,18 @@
)
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
device = self.device if move_to_device else None
if isinstance(self._strategy, TPUSpawnPlugin):
dataloader = DataLoader(**kwargs)
else:
dataloader = _LiteDataLoader(device=device, **kwargs)
dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
try:
dataloader = type(dataloader)(**dataloader_kwargs)
except TypeError:
dataloader_kwargs.pop("dataset")
dataloader = type(dataloader)(**dataloader_kwargs)
# add worker_init_fn for correct seeding in worker processes
TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
return self._strategy.process_dataloader(dataloader)
return _LiteDataLoader(
dataloader=self._strategy.process_dataloader(dataloader),
device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
)
def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
@@ -400,7 +406,7 @@ class LightningLite(ABC):
return run_method(*args, **kwargs)
def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
with self._strategy.model_sharded_context():
with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
return run_method(*args, **kwargs)
def _set_plugin_specific_precision_variables(self) -> None:
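
For context on the try/except added to `_setup_dataloader` above: a subclass that builds its own dataset from another constructor argument cannot be re-created with `dataset` in the keyword arguments, so `dataset` is dropped and the call retried. A simplified, hypothetical illustration (the `ValueLoader` class and the hand-written kwargs only stand in for what `_get_dataloader_init_kwargs` produces):

from torch.utils.data import DataLoader


class ValueLoader(DataLoader):
    def __init__(self, value: int, *args, **kwargs):
        self.value = value
        super().__init__(range(value), *args, **kwargs)


kwargs = {"value": 2, "dataset": range(2), "batch_size": 2}
try:
    reloaded = ValueLoader(**kwargs)  # TypeError: `dataset` ends up passed twice
except TypeError:
    kwargs.pop("dataset")
    reloaded = ValueLoader(**kwargs)  # retry without `dataset` succeeds

assert reloaded.value == 2 and list(reloaded)[0].tolist() == [0, 1]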

pytorch_lightning/lite/wrappers.py

@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, Iterator, Optional, Union
import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type, Union
import torch
from torch import nn as nn
@@ -100,17 +103,65 @@ class _LiteModule(nn.Module):
return output
class _LiteDataLoader(DataLoader):
def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None:
"""The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds
additional features such as moving the data to the device automatically.
def _wrap_init(f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
params = dict(inspect.signature(module._old_init).parameters)
params.pop("args")
params.pop("kwargs")
for init_name, init_arg in zip(params, args):
setattr(module, init_name, init_arg)
f(module, *args, **kwargs)
return wrapper
# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
subclass_list = []
def recurse(cl: Type[Any]) -> None:
for subclass in cl.__subclasses__():
subclass_list.append(subclass)
recurse(subclass)
recurse(cls)
return set(subclass_list)
def _enable_class(cls: Type[Any]) -> None:
cls._old_init = cls.__init__
cls.__init__ = _wrap_init(cls.__init__)
def _disable_class(cls: Type[Any]) -> None:
cls.__init__ = cls._old_init
del cls._old_init
@contextmanager
def _replace_dataloader_init_method() -> Generator:
"""This context manager is used to support custom :class:`~torch.utils.data.DataLoader."""
for subclass in _get_all_subclasses(DataLoader):
_enable_class(subclass)
yield
for subclass in _get_all_subclasses(DataLoader):
_disable_class(subclass)
class _LiteDataLoader:
def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None:
"""The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if
the device is specified.
Args:
dataloader: The current dataloader to be used.
device: The device to which the data should be moved. By default the device is `None` and no data
transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
**dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts.
"""
super().__init__(**dl_kwargs)
super().__init__()
self.__dict__.update(getattr(dataloader, "__dict__", {}))
self._dataloader = dataloader
self._device = device
@property
@@ -118,9 +169,9 @@ class _LiteDataLoader(DataLoader):
return self._device
def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
iterator = super().__iter__()
dataloader_iter = iter(self._dataloader)
if self._device is None:
return iterator
return dataloader_iter
for item in iterator:
for item in dataloader_iter:
yield move_data_to_device(item, self._device)
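
A rough sketch of what `_replace_dataloader_init_method` and `_wrap_init` above accomplish (an internal helper added in this commit; the `RangeLoader` class is made up for illustration): while the context manager is active, positional arguments passed to a `DataLoader` subclass's `__init__` are recorded as attributes on the instance, which is what later allows `_setup_dataloader` to re-instantiate the subclass with the same arguments.

from torch.utils.data import DataLoader
from pytorch_lightning.lite.wrappers import _replace_dataloader_init_method


class RangeLoader(DataLoader):
    def __init__(self, length, *args, **kwargs):
        super().__init__(range(length), *args, **kwargs)


with _replace_dataloader_init_method():
    # __init__ of every DataLoader subclass is temporarily wrapped here.
    loader = RangeLoader(4, batch_size=2)

# `length` was never assigned in __init__, but the wrapper stored it from the call.
assert loader.length == 4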

tests/lite/test_lite.py

@@ -164,6 +164,57 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1
def test_setup_custom_dataloaders():
"""Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
lite = EmptyLite()
class CustomDataLoader(DataLoader):
def __init__(self, value: int = 2, *args, **kwargs):
self.value = value
super().__init__(range(value), *args, **kwargs)
dataloader = CustomDataLoader(2, batch_size=2)
# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
assert lite_dataloader.value == 2
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))
class CustomDataLoader2(DataLoader):
def __init__(self, range, *args, **kwargs):
self.range = range
super().__init__(range, *args, **kwargs)
dataloader = CustomDataLoader2(range(2), batch_size=2)
# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))
class CustomDataLoader(DataLoader):
def __init__(self, value: int, *args, **kwargs):
super().__init__(range(value), *args, **kwargs)
class LiteWithCustomDataLoader(LightningLite):
def run(self):
# This doesn't fail as the context manager would save all the arguments provided
# to the dataloaders.
dataloader = CustomDataLoader(2, batch_size=2)
self.setup_dataloaders(dataloader)
LiteWithCustomDataLoader().run()
with pytest.raises(
MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
):
dataloader = CustomDataLoader(2, batch_size=2)
lite_dataloader = lite.setup_dataloaders(dataloader)
def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
lite = EmptyLite()

tests/lite/test_wrappers.py

@@ -15,6 +15,7 @@ from unittest.mock import ANY, Mock
import pytest
import torch
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -73,8 +74,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
sample1 = torch.tensor(1, device=src_device)
sample2 = {"data": torch.tensor(2, device=src_device)}
sample3 = {"data": torch.tensor(3, device=src_device)}
data = [sample0, sample1, sample2, sample3]
lite_dataloader = _LiteDataLoader(device=dest_device, dataset=data, batch_size=2)
dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=dest_device)
iterator = iter(lite_dataloader)
batch0 = next(iterator)
@@ -83,6 +84,15 @@
batch1 = next(iterator)
assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))
with pytest.raises(StopIteration):
batch1 = next(iterator)
lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device)
iterator = iter(lite_dataloader)
batch0 = next(iterator)
assert batch0 == 0
def test_lite_optimizer_wraps():
"""Test that the LiteOptimizer fully wraps the optimizer."""