From 4c7b9f0b1114a87ab9a0b2fd7b8952d0b22d5c40 Mon Sep 17 00:00:00 2001
From: otaj <6065855+otaj@users.noreply.github.com>
Date: Wed, 27 Jul 2022 11:50:43 +0200
Subject: [PATCH] Disallow batch sampler with multiple IPU devices (#13854)

Co-authored-by: Rohit Gupta
---
 .../source-pytorch/accelerators/ipu_basic.rst |  9 +--
 src/pytorch_lightning/CHANGELOG.md            |  2 +-
 src/pytorch_lightning/strategies/ipu.py       |  6 +-
 src/pytorch_lightning/utilities/data.py       | 64 ++++++++++++-------
 tests/tests_pytorch/accelerators/test_ipu.py  |  6 +-
 tests/tests_pytorch/utilities/test_data.py    | 20 +++++-
 6 files changed, 76 insertions(+), 31 deletions(-)

diff --git a/docs/source-pytorch/accelerators/ipu_basic.rst b/docs/source-pytorch/accelerators/ipu_basic.rst
index 6ff0cb701d..99a5c69a10 100644
--- a/docs/source-pytorch/accelerators/ipu_basic.rst
+++ b/docs/source-pytorch/accelerators/ipu_basic.rst
@@ -62,7 +62,8 @@ Currently there are some known limitations that are being addressed in the near
 Please see the `MNIST example `__ which displays most of the limitations and how to overcome them till they are resolved.

-* ``self.log`` is not supported in the ``training_step``, ``validation_step``, ``test_step`` or ``predict_step``. This is due to the step function being traced and sent to the IPU devices. We're actively working on fixing this
-* Multiple optimizers are not supported. ``training_step`` only supports returning one loss from the ``training_step`` function as a result
-* Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code
-* Clipping gradients is not supported
+* ``self.log`` is not supported in the ``training_step``, ``validation_step``, ``test_step`` or ``predict_step``. This is due to the step function being traced and sent to the IPU devices. We're actively working on fixing this.
+* Multiple optimizers are not supported. ``training_step`` only supports returning one loss from the ``training_step`` function as a result.
+* Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code.
+* Clipping gradients is not supported.
+* It is not possible to use :class:`torch.utils.data.BatchSampler` in your dataloaders if you are using multiple IPUs.
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 4af493b7f3..f8341248b2 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -167,7 +167,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Updated Habana Accelerator's `auto_device_count`, `is_available` & `get_device_name` methods based on the latest torch habana package ([#13423](https://github.com/PyTorchLightning/pytorch-lightning/pull/13423))

--
+- Disallowed using `BatchSampler` when running on multiple IPUs ([#13854](https://github.com/PyTorchLightning/pytorch-lightning/pull/13854))


 ### Deprecated

diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py
index 5413756c15..001ad77fbb 100644
--- a/src/pytorch_lightning/strategies/ipu.py
+++ b/src/pytorch_lightning/strategies/ipu.py
@@ -162,6 +162,8 @@ class IPUStrategy(ParallelStrategy):
             if self.lightning_module.trainer.enable_validation:
                 model = poptorch.inferenceModel(model=model, options=inference_opts)
                 self.poptorch_models[RunningStage.VALIDATING] = model
+                if self.lightning_module.trainer.num_sanity_val_steps > 0:
+                    self.poptorch_models[RunningStage.SANITY_CHECKING] = model
         elif trainer_fn == TrainerFn.VALIDATING:
             model = poptorch.inferenceModel(model=model, options=self.inference_opts)
             self.poptorch_models[RunningStage.VALIDATING] = model
@@ -228,7 +230,9 @@ class IPUStrategy(ParallelStrategy):
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader

-        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
+        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
+            dataloader, sampler, mode, self.replication_factor > 1
+        )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
         return dataloader
diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py
index 2de82ceff0..e60c56f6c7 100644
--- a/src/pytorch_lightning/utilities/data.py
+++ b/src/pytorch_lightning/utilities/data.py
@@ -186,7 +186,7 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
 def _update_dataloader(
     dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
 ) -> DataLoader:
-    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode)
+    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
     dl_cls = type(dataloader)
     try:
         dataloader = dl_cls(*dl_args, **dl_kwargs)
@@ -212,7 +212,10 @@ def _update_dataloader(


 def _get_dataloader_init_args_and_kwargs(
-    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
+    dataloader: DataLoader,
+    sampler: Optional[Sampler],
+    mode: Optional[RunningStage] = None,
+    disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
@@ -264,7 +267,7 @@ def _get_dataloader_init_args_and_kwargs(
         dl_kwargs["batch_sampler"] = None
         dl_kwargs["sampler"] = None
     else:
-        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode))
+        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode, disallow_batch_sampler))

     required_args = {
         p.name
@@ -309,7 +312,10 @@ def _get_dataloader_init_args_and_kwargs(


 def _dataloader_init_kwargs_resolve_sampler(
-    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
+    dataloader: DataLoader,
+    sampler: Optional[Sampler],
+    mode: Optional[RunningStage] = None,
+    disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
     re-instantiation.
@@ -321,27 +327,39 @@ def _dataloader_init_kwargs_resolve_sampler(

     fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
     batch_sampler = getattr(dataloader, "batch_sampler")
     is_predicting = mode == RunningStage.PREDICTING
-    # checking the batch sampler type is different than PyTorch default.
-    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
-        batch_sampler = type(batch_sampler)(
-            sampler,
-            batch_size=batch_sampler.batch_size,
-            drop_last=(False if is_predicting else batch_sampler.drop_last),
-        )
-        if is_predicting:
-            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
-        if fault_tolerant_mode.is_automatic:
-            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
-            fast_forward_sampler.setup(dataloader_batch_size=1)
+    if batch_sampler is not None:
+        if disallow_batch_sampler:
+            # Allow only the default BatchSampler that DataLoader.__init__ instantiated from this dataloader's sampler
+            if not (
+                type(batch_sampler) is BatchSampler
+                and batch_sampler.sampler == sampler
+                and dataloader.batch_size == batch_sampler.batch_size
+            ):
+                raise MisconfigurationException(
+                    "It is not possible to have a batch sampler in your dataloader "
+                    "when running on multiple IPU devices."
+                )
+        elif type(batch_sampler) is not BatchSampler or is_predicting:
+            batch_sampler = type(batch_sampler)(
+                sampler,
+                batch_size=batch_sampler.batch_size,
+                drop_last=(False if is_predicting else batch_sampler.drop_last),
+            )
+            if is_predicting:
+                batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

-        return {
-            "sampler": None,
-            "shuffle": False,
-            "batch_sampler": batch_sampler,
-            "batch_size": 1,
-            "drop_last": False,
-        }
+            if fault_tolerant_mode.is_automatic:
+                fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
+                fast_forward_sampler.setup(dataloader_batch_size=1)
+
+            return {
+                "sampler": None,
+                "shuffle": False,
+                "batch_sampler": batch_sampler,
+                "batch_size": 1,
+                "drop_last": False,
+            }

     if fault_tolerant_mode.is_automatic:
         fast_forward_sampler = sampler = FastForwardSampler(sampler)
diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py
index 97f374a40d..589ec7b29d 100644
--- a/tests/tests_pytorch/accelerators/test_ipu.py
+++ b/tests/tests_pytorch/accelerators/test_ipu.py
@@ -619,7 +619,11 @@ def test_poptorch_models_at_different_stages(tmpdir):
     trainer.optimizers = model.configure_optimizers()[0]
     trainer.state.fn = TrainerFn.FITTING
     trainer.strategy.setup(trainer)
-    assert list(trainer.strategy.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]
+    assert list(trainer.strategy.poptorch_models) == [
+        RunningStage.TRAINING,
+        RunningStage.VALIDATING,
+        RunningStage.SANITY_CHECKING,
+    ]

     for fn, stage in (
         (TrainerFn.VALIDATING, RunningStage.VALIDATING),
diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py
index 7b1e596d50..5f66d802ea 100644
--- a/tests/tests_pytorch/utilities/test_data.py
+++ b/tests/tests_pytorch/utilities/test_data.py
@@ -3,12 +3,13 @@ from dataclasses import dataclass
 import pytest
 import torch
 from torch import Tensor
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.data import (
+    _dataloader_init_kwargs_resolve_sampler,
     _get_dataloader_init_args_and_kwargs,
     _replace_dataloader_init_method,
     _update_dataloader,
@@ -331,6 +332,23 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c
         assert getattr(dataloader, key) == value


+def test_dataloader_disallow_batch_sampler():
+    dataset = RandomDataset(5, 100)
+    dataloader = DataLoader(dataset, batch_size=10)
+
+    # This should not raise: the default BatchSampler was instantiated by the DataLoader itself
+    _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
+
+    dataset = RandomDataset(5, 100)
+    sampler = SequentialSampler(dataset)
+    batch_sampler = BatchSampler(sampler, batch_size=10, drop_last=False)
+    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+
+    # This should raise, since the batch sampler was not automatically instantiated by the DataLoader
+    with pytest.raises(MisconfigurationException, match="when running on multiple IPU devices"):
+        _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
+
+
 @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
 def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
     """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
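
Below is a standalone sketch of the guard this patch adds, written against plain PyTorch only. The names `check_batch_sampler` and `IPUBatchSamplerError` are hypothetical stand-ins for this illustration; in the patch itself the check lives in `_dataloader_init_kwargs_resolve_sampler` and raises `MisconfigurationException`. The comparison mirrors the added code: a batch sampler is tolerated only when it is the default `BatchSampler` that `DataLoader.__init__` built from the dataloader's own sampler and batch size.

```python
# Minimal sketch, assuming only vanilla PyTorch. `check_batch_sampler` and
# `IPUBatchSamplerError` are hypothetical names used for this example.
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset


class IPUBatchSamplerError(Exception):
    """Stand-in for pytorch_lightning's MisconfigurationException."""


def check_batch_sampler(dataloader: DataLoader) -> None:
    batch_sampler = dataloader.batch_sampler
    if batch_sampler is None:
        return
    # The only acceptable batch sampler is the default one that DataLoader
    # itself created from the dataloader's sampler and batch_size.
    is_default = (
        type(batch_sampler) is BatchSampler
        and batch_sampler.sampler == dataloader.sampler
        and dataloader.batch_size == batch_sampler.batch_size
    )
    if not is_default:
        raise IPUBatchSamplerError(
            "It is not possible to have a batch sampler in your dataloader "
            "when running on multiple IPU devices."
        )


dataset = TensorDataset(torch.randn(100, 5))

# Default batch sampler, created internally by DataLoader: passes the check.
check_batch_sampler(DataLoader(dataset, batch_size=10))

# Explicitly supplied batch sampler: rejected, mirroring the new unit test.
custom = BatchSampler(SequentialSampler(dataset), batch_size=10, drop_last=False)
try:
    check_batch_sampler(DataLoader(dataset, batch_sampler=custom))
except IPUBatchSamplerError as err:
    print(err)
```

The default `BatchSampler` is let through because it carries no information beyond the dataloader's own `sampler` and `batch_size`, which the strategy can hand to `poptorch.DataLoader` directly; a customized batch sampler cannot be reproduced that way, hence the hard error whenever `self.replication_factor > 1`.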