Add custom dataloader support with Lite (#10279)

parent 828b5315aa · commit facaff94b8
@@ -119,6 +119,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164))
 * Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162))
 * Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175))
+* Made `_LiteDataLoader` an iterator and added support for custom dataloaders ([#10279](https://github.com/PyTorchLightning/pytorch-lightning/pull/10279))
 - Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
 - Added `ckpt_path` argument for `Trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
 - Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
@@ -25,7 +25,12 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
+from pytorch_lightning.lite.wrappers import (
+    _LiteDataLoader,
+    _LiteModule,
+    _LiteOptimizer,
+    _replace_dataloader_init_method,
+)
 from pytorch_lightning.plugins import (
     DDPShardedPlugin,
     DDPSpawnPlugin,
@@ -183,7 +188,7 @@ class LightningLite(ABC):
 
     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[DataLoader, List[DataLoader], Iterable]:
+    ) -> Union[Iterable, List[Iterable]]:
         """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
         dataloader, call this method individually for each one.
 
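With this change, `setup_dataloaders` returns an `Iterable` (the new `_LiteDataLoader` wrapper below) rather than a `DataLoader` subclass. A minimal usage sketch from the caller's side (hypothetical user code, not part of this commit):

```python
from torch.utils.data import DataLoader

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        # a DataLoader goes in, an Iterable (_LiteDataLoader) comes out
        train_loader = self.setup_dataloaders(DataLoader(range(8), batch_size=2))
        for batch in train_loader:
            ...


Lite(accelerator="cpu").run()
```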
@@ -208,7 +213,7 @@ class LightningLite(ABC):
 
     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[Iterable, DataLoader]:
+    ) -> Iterable:
        """Setup a single dataloader for accelerated training.
 
        Args:
@@ -233,17 +238,18 @@ class LightningLite(ABC):
                 )
             sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
 
-        kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
-        device = self.device if move_to_device else None
-        if isinstance(self._strategy, TPUSpawnPlugin):
-            dataloader = DataLoader(**kwargs)
-        else:
-            dataloader = _LiteDataLoader(device=device, **kwargs)
+        dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
+        try:
+            dataloader = type(dataloader)(**dataloader_kwargs)
+        except TypeError:
+            dataloader_kwargs.pop("dataset")
+            dataloader = type(dataloader)(**dataloader_kwargs)
         # add worker_init_fn for correct seeding in worker processes
         TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
 
-        return self._strategy.process_dataloader(dataloader)
+        return _LiteDataLoader(
+            dataloader=self._strategy.process_dataloader(dataloader),
+            device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
+        )
 
     def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
         """Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you.
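The new body re-creates the user's dataloader class from the captured init kwargs and retries without `dataset` for classes that build their own dataset internally. A standalone sketch of that retry logic, with a hypothetical `MyLoader` (the real kwargs come from `_get_dataloader_init_kwargs`):

```python
from torch.utils.data import DataLoader


class MyLoader(DataLoader):  # hypothetical custom subclass
    def __init__(self, value: int, **kwargs):
        self.value = value
        super().__init__(range(value), **kwargs)


loader = MyLoader(4, batch_size=2)

# rebuild the same class from kwargs matching its __init__ signature,
# e.g. after swapping in a distributed sampler
init_kwargs = {"value": loader.value, "batch_size": loader.batch_size, "dataset": loader.dataset}
try:
    rebuilt = type(loader)(**init_kwargs)
except TypeError:
    # MyLoader builds its own dataset, so it rejects an explicit `dataset`
    init_kwargs.pop("dataset")
    rebuilt = type(loader)(**init_kwargs)

assert rebuilt.value == 4
```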
@@ -400,7 +406,7 @@ class LightningLite(ABC):
         return run_method(*args, **kwargs)
 
     def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
-        with self._strategy.model_sharded_context():
+        with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
             return run_method(*args, **kwargs)
 
     def _set_plugin_specific_precision_variables(self) -> None:
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Generator, Iterator, Optional, Union
+import functools
+import inspect
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type, Union
 
 import torch
 from torch import nn as nn
@@ -100,17 +103,65 @@ class _LiteModule(nn.Module):
         return output
 
 
-class _LiteDataLoader(DataLoader):
-    def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None:
-        """The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds
-        additional features such as moving the data to the device automatically.
+def _wrap_init(f: Callable) -> Callable:
+    @functools.wraps(f)
+    def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
+        params = dict(inspect.signature(module._old_init).parameters)
+        params.pop("args")
+        params.pop("kwargs")
+        for init_name, init_arg in zip(params, args):
+            setattr(module, init_name, init_arg)
+        f(module, *args, **kwargs)
+
+    return wrapper
+
+
+# https://stackoverflow.com/a/63851681/9201239
+def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
+    subclass_list = []
+
+    def recurse(cl: Type[Any]) -> None:
+        for subclass in cl.__subclasses__():
+            subclass_list.append(subclass)
+            recurse(subclass)
+
+    recurse(cls)
+    return set(subclass_list)
+
+
+def _enable_class(cls: Type[Any]) -> None:
+    cls._old_init = cls.__init__
+    cls.__init__ = _wrap_init(cls.__init__)
+
+
+def _disable_class(cls: Type[Any]) -> None:
+    cls.__init__ = cls._old_init
+    del cls._old_init
+
+
+@contextmanager
+def _replace_dataloader_init_method() -> Generator:
+    """This context manager is used to support custom :class:`~torch.utils.data.DataLoader` subclasses."""
+    for subclass in _get_all_subclasses(DataLoader):
+        _enable_class(subclass)
+    yield
+    for subclass in _get_all_subclasses(DataLoader):
+        _disable_class(subclass)
+
+
+class _LiteDataLoader:
+    def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None:
+        """The LiteDataLoader is a wrapper around an iterable. It moves the data to the device automatically if a
+        device is specified.
 
         Args:
+            dataloader: The current dataloader to be used.
             device: The device to which the data should be moved. By default the device is `None` and no data
                 transfers will be made (identical behavior to :class:`~torch.utils.data.DataLoader`).
-            **dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts.
         """
-        super().__init__(**dl_kwargs)
+        super().__init__()
+        self.__dict__.update(getattr(dataloader, "__dict__", {}))
+        self._dataloader = dataloader
         self._device = device
 
     @property
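The helpers above patch `__init__` on every `DataLoader` subclass so that positional constructor arguments get recorded as instance attributes; note that `module._old_init` is a bound method, so its signature no longer contains `self`. A self-contained sketch of the same technique on a generic class (all names illustrative):

```python
import functools
import inspect


class Base:  # stand-in for a DataLoader subclass
    def __init__(self, a, b=2):
        pass


def _wrap(f):
    @functools.wraps(f)
    def wrapper(obj, *args, **kwargs):
        # obj._old_init is a bound method, so its signature already drops `self`
        params = dict(inspect.signature(obj._old_init).parameters)
        for name, value in zip(params, args):
            setattr(obj, name, value)  # record positional args by parameter name
        f(obj, *args, **kwargs)

    return wrapper


Base._old_init = Base.__init__
Base.__init__ = _wrap(Base.__init__)

obj = Base(1)
assert obj.a == 1  # the positional argument was captured as an attribute
```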
@@ -118,9 +169,12 @@ class _LiteDataLoader(DataLoader):
         return self._device
 
     def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
-        iterator = super().__iter__()
+        dataloader_iter = iter(self._dataloader)
         if self._device is None:
-            return iterator
+            # this is a generator function (it contains ``yield``), so delegate
+            # to the underlying iterator instead of returning it
+            yield from dataloader_iter
+            return
 
-        for item in iterator:
+        for item in dataloader_iter:
             yield move_data_to_device(item, self._device)
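The `yield from` in the new `__iter__` matters: because the body contains `yield`, Python treats the whole function as a generator function, so an early `return iterator` would end the generator immediately instead of handing the iterator to the caller. A minimal demonstration of the pitfall:

```python
def broken():
    it = iter([1, 2, 3])
    return it  # inside a generator function this just raises StopIteration
    yield  # the mere presence of `yield` makes this a generator function


def fixed():
    it = iter([1, 2, 3])
    yield from it


assert list(broken()) == []  # the returned iterator is lost
assert list(fixed()) == [1, 2, 3]
```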
@@ -164,6 +164,57 @@ def test_setup_dataloaders_return_type():
     assert lite_dataloader1.dataset is dataset1
 
 
+def test_setup_custom_dataloaders():
+    """Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
+    lite = EmptyLite()
+
+    class CustomDataLoader(DataLoader):
+        def __init__(self, value: int = 2, *args, **kwargs):
+            self.value = value
+            super().__init__(range(value), *args, **kwargs)
+
+    dataloader = CustomDataLoader(2, batch_size=2)
+
+    # single dataloader
+    lite_dataloader = lite.setup_dataloaders(dataloader)
+    assert lite_dataloader._dataloader
+    assert lite_dataloader.value == 2
+    batch0 = next(iter(lite_dataloader))
+    assert torch.equal(batch0, torch.tensor([0, 1]))
+
+    class CustomDataLoader2(DataLoader):
+        def __init__(self, range, *args, **kwargs):
+            self.range = range
+            super().__init__(range, *args, **kwargs)
+
+    dataloader = CustomDataLoader2(range(2), batch_size=2)
+
+    # single dataloader
+    lite_dataloader = lite.setup_dataloaders(dataloader)
+    assert lite_dataloader._dataloader
+    batch0 = next(iter(lite_dataloader))
+    assert torch.equal(batch0, torch.tensor([0, 1]))
+
+    class CustomDataLoader(DataLoader):
+        def __init__(self, value: int, *args, **kwargs):
+            super().__init__(range(value), *args, **kwargs)
+
+    class LiteWithCustomDataLoader(LightningLite):
+        def run(self):
+            # this doesn't fail because the context manager saves all the arguments
+            # provided to the dataloader's __init__
+            dataloader = CustomDataLoader(2, batch_size=2)
+            self.setup_dataloaders(dataloader)
+
+    LiteWithCustomDataLoader().run()
+
+    with pytest.raises(
+        MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
+    ):
+        dataloader = CustomDataLoader(2, batch_size=2)
+        lite_dataloader = lite.setup_dataloaders(dataloader)
+
+
 def test_setup_dataloaders_twice_fails():
     """Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
     lite = EmptyLite()
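Taken together, the user-facing behavior these tests pin down: a custom `DataLoader` subclass created inside `run()` can be passed straight to `setup_dataloaders`, because the init-patching context manager records its constructor arguments. A sketch (hypothetical user code):

```python
from torch.utils.data import DataLoader

from pytorch_lightning.lite import LightningLite


class MyLoader(DataLoader):  # builds its own dataset, stores nothing itself
    def __init__(self, size: int, **kwargs):
        super().__init__(range(size), **kwargs)


class Lite(LightningLite):
    def run(self):
        # created inside run(), so `size` is recorded and the loader can be
        # re-instantiated (e.g. with a DistributedSampler) by setup_dataloaders
        loader = self.setup_dataloaders(MyLoader(8, batch_size=4))
        for batch in loader:
            ...


Lite(accelerator="cpu").run()
```

Outside `run()` the patch is not active, which is why the final `pytest.raises` block above expects the sampler-injection error for an equivalent loader that does not store its own init arguments.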
@@ -15,6 +15,7 @@ from unittest.mock import ANY, Mock
 
 import pytest
 import torch
+from torch.utils.data.dataloader import DataLoader
 
 from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -73,8 +74,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
     sample1 = torch.tensor(1, device=src_device)
     sample2 = {"data": torch.tensor(2, device=src_device)}
     sample3 = {"data": torch.tensor(3, device=src_device)}
-    data = [sample0, sample1, sample2, sample3]
-    lite_dataloader = _LiteDataLoader(device=dest_device, dataset=data, batch_size=2)
+    dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
+    lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=dest_device)
     iterator = iter(lite_dataloader)
 
     batch0 = next(iterator)
@@ -83,6 +84,15 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
     batch1 = next(iterator)
     assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))
 
+    with pytest.raises(StopIteration):
+        batch1 = next(iterator)
+
+    lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device)
+    iterator = iter(lite_dataloader)
+
+    batch0 = next(iterator)
+    assert batch0 == 0
+
 
 def test_lite_optimizer_wraps():
     """Test that the LiteOptimizer fully wraps the optimizer."""
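For reference, the device placement asserted in this test is performed by `move_data_to_device`, which moves tensors nested in common collections and passes everything else through. A quick illustration:

```python
import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device

batch = {"data": torch.tensor([2, 3]), "meta": "unchanged"}
moved = move_data_to_device(batch, torch.device("cpu"))
assert torch.equal(moved["data"], torch.tensor([2, 3]))
assert moved["meta"] == "unchanged"  # non-tensors pass through untouched
```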