Disallow batch sampler with multiple IPU devices (#13854)

Author: otaj (committed by GitHub)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Date: 2022-07-27 11:50:43 +02:00
Parent: 56b1e1aaaa
Commit: 4c7b9f0b11
6 changed files with 76 additions and 31 deletions


@@ -62,7 +62,8 @@ Currently there are some known limitations that are being addressed in the near
 Please see the `MNIST example <https://github.com/Lightning-AI/lightning/blob/master/examples/pl_ipu/mnist_sample.py>`__ which displays most of the limitations and how to overcome them till they are resolved.

-* ``self.log`` is not supported in the ``training_step``, ``validation_step``, ``test_step`` or ``predict_step``. This is due to the step function being traced and sent to the IPU devices. We're actively working on fixing this
-* Multiple optimizers are not supported. ``training_step`` only supports returning one loss from the ``training_step`` function as a result
-* Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code
-* Clipping gradients is not supported
+* ``self.log`` is not supported in the ``training_step``, ``validation_step``, ``test_step`` or ``predict_step``. This is due to the step function being traced and sent to the IPU devices. We're actively working on fixing this.
+* Multiple optimizers are not supported. ``training_step`` only supports returning one loss from the ``training_step`` function as a result.
+* Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code.
+* Clipping gradients is not supported.
+* It is not possible to use :class:`torch.utils.data.BatchSampler` in your dataloaders if you are using multiple IPUs.
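For context on the last bullet added above, here is a minimal sketch of what the new limitation looks like from user code. It assumes poptorch and IPU hardware are available; the device count, dataset, and model are made up for illustration:

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset

dataset = RandomDataset(32, 64)

# Fine on multiple IPUs: DataLoader builds its default BatchSampler internally.
ok_loader = DataLoader(dataset, batch_size=8)

# Not supported on multiple IPUs: an explicitly constructed BatchSampler.
custom_sampler = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
bad_loader = DataLoader(dataset, batch_sampler=custom_sampler)

trainer = Trainer(accelerator="ipu", devices=8, max_epochs=1)
trainer.fit(BoringModel(), train_dataloaders=ok_loader)  # works
# trainer.fit(BoringModel(), train_dataloaders=bad_loader)  # raises MisconfigurationException
```

Passing `batch_size` lets `DataLoader` create its default `BatchSampler`, which Lightning can still rebuild safely; an explicitly constructed `BatchSampler` is now rejected with a `MisconfigurationException` when the loaders are rebuilt for multiple IPU devices.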


@@ -167,7 +167,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Updated Habana Accelerator's `auto_device_count`, `is_available` & `get_device_name` methods based on the latest torch habana package ([#13423](https://github.com/PyTorchLightning/pytorch-lightning/pull/13423))

-
+- Disallowed using `BatchSampler` when running on multiple IPUs ([#13854](https://github.com/PyTorchLightning/pytorch-lightning/pull/13854))

 ### Deprecated


@@ -162,6 +162,8 @@ class IPUStrategy(ParallelStrategy):
             if self.lightning_module.trainer.enable_validation:
                 model = poptorch.inferenceModel(model=model, options=inference_opts)
                 self.poptorch_models[RunningStage.VALIDATING] = model
+                if self.lightning_module.trainer.num_sanity_val_steps > 0:
+                    self.poptorch_models[RunningStage.SANITY_CHECKING] = model
         elif trainer_fn == TrainerFn.VALIDATING:
             model = poptorch.inferenceModel(model=model, options=self.inference_opts)
             self.poptorch_models[RunningStage.VALIDATING] = model
@@ -228,7 +230,9 @@ class IPUStrategy(ParallelStrategy):
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader

-        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
+        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
+            dataloader, sampler, mode, self.replication_factor > 1
+        )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
         return dataloader
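Two things happen in the strategy hunks above: when sanity checking will run during fitting (`num_sanity_val_steps > 0`), the inference model built for validation is also registered under `RunningStage.SANITY_CHECKING`, and the dataloader rebuild now passes `self.replication_factor > 1` as the new `disallow_batch_sampler` argument. A rough, self-contained sketch of that gating (the helper below is illustrative, not Lightning's implementation):

```python
from typing import Any, Dict, Tuple

from torch.utils.data import DataLoader


def convert_loader_sketch(dataloader: DataLoader, replication_factor: int) -> Tuple[DataLoader, Dict[str, Any]]:
    # Mirrors the call in the diff: custom batch samplers are only forbidden
    # when the model is replicated over more than one IPU.
    disallow_batch_sampler = replication_factor > 1
    # A real strategy would rebuild the loader (e.g. as a poptorch.DataLoader) here;
    # this sketch only surfaces the decision.
    return dataloader, {"disallow_batch_sampler": disallow_batch_sampler}


loader = DataLoader(range(16), batch_size=4)
print(convert_loader_sketch(loader, replication_factor=1)[1])  # {'disallow_batch_sampler': False}
print(convert_loader_sketch(loader, replication_factor=8)[1])  # {'disallow_batch_sampler': True}
```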


@@ -186,7 +186,7 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
 def _update_dataloader(
     dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
 ) -> DataLoader:
-    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode)
+    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
     dl_cls = type(dataloader)
     try:
         dataloader = dl_cls(*dl_args, **dl_kwargs)
@@ -212,7 +212,10 @@ def _update_dataloader(
 def _get_dataloader_init_args_and_kwargs(
-    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
+    dataloader: DataLoader,
+    sampler: Optional[Sampler],
+    mode: Optional[RunningStage] = None,
+    disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
@@ -264,7 +267,7 @@ def _get_dataloader_init_args_and_kwargs(
         dl_kwargs["batch_sampler"] = None
         dl_kwargs["sampler"] = None
     else:
-        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode))
+        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode, disallow_batch_sampler))

     required_args = {
         p.name
@@ -309,7 +312,10 @@ def _get_dataloader_init_args_and_kwargs(
 def _dataloader_init_kwargs_resolve_sampler(
-    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
+    dataloader: DataLoader,
+    sampler: Optional[Sampler],
+    mode: Optional[RunningStage] = None,
+    disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
     re-instantiation.
@@ -321,27 +327,39 @@ def _dataloader_init_kwargs_resolve_sampler(
     fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
     batch_sampler = getattr(dataloader, "batch_sampler")
     is_predicting = mode == RunningStage.PREDICTING
-    # checking the batch sampler type is different than PyTorch default.
-    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
-        batch_sampler = type(batch_sampler)(
-            sampler,
-            batch_size=batch_sampler.batch_size,
-            drop_last=(False if is_predicting else batch_sampler.drop_last),
-        )
-        if is_predicting:
-            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
-
-        if fault_tolerant_mode.is_automatic:
-            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
-            fast_forward_sampler.setup(dataloader_batch_size=1)
+    if batch_sampler is not None:
+        if disallow_batch_sampler:
+            # Check that we don't have a PyTorch default batch sampler that was instantiated in DataLoader __init__
+            if not (
+                type(batch_sampler) is BatchSampler
+                and batch_sampler.sampler == sampler
+                and dataloader.batch_size == batch_sampler.batch_size
+            ):
+                raise MisconfigurationException(
+                    "It is not possible to have a batch sampler in your dataloader, "
+                    "when running on multiple IPU devices."
+                )
+        elif type(batch_sampler) is not BatchSampler or is_predicting:
+            batch_sampler = type(batch_sampler)(
+                sampler,
+                batch_size=batch_sampler.batch_size,
+                drop_last=(False if is_predicting else batch_sampler.drop_last),
+            )
+            if is_predicting:
+                batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

-        return {
-            "sampler": None,
-            "shuffle": False,
-            "batch_sampler": batch_sampler,
-            "batch_size": 1,
-            "drop_last": False,
-        }
+            if fault_tolerant_mode.is_automatic:
+                fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
+                fast_forward_sampler.setup(dataloader_batch_size=1)
+
+            return {
+                "sampler": None,
+                "shuffle": False,
+                "batch_sampler": batch_sampler,
+                "batch_size": 1,
+                "drop_last": False,
+            }

     if fault_tolerant_mode.is_automatic:
         fast_forward_sampler = sampler = FastForwardSampler(sampler)
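The heart of the change is the check that distinguishes the `BatchSampler` that `DataLoader.__init__` creates on its own from one supplied by the user. A small standalone sketch of that idea (`_is_default_batch_sampler` is a hypothetical helper, simplified from the guard in the diff above):

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler


def _is_default_batch_sampler(dataloader: DataLoader) -> bool:
    """Hypothetical helper mirroring the guard above.

    When only `batch_size` is passed, DataLoader.__init__ wraps its own sampler
    in a plain BatchSampler; anything else is treated as user-provided.
    """
    batch_sampler = dataloader.batch_sampler
    if batch_sampler is None:
        return True  # nothing to reject
    return (
        type(batch_sampler) is BatchSampler
        and batch_sampler.sampler is dataloader.sampler
        and dataloader.batch_size == batch_sampler.batch_size
    )


data = list(range(100))

auto = DataLoader(data, batch_size=10)
assert _is_default_batch_sampler(auto)  # created by DataLoader.__init__, allowed

user_bs = BatchSampler(SequentialSampler(data), batch_size=10, drop_last=False)
custom = DataLoader(data, batch_sampler=user_bs)
assert not _is_default_batch_sampler(custom)  # user-supplied, rejected on multiple IPUs
```

The real guard compares against the sampler Lightning is about to inject rather than `dataloader.sampler`, but the intent is the same: only the batch sampler that `DataLoader` created implicitly from `batch_size` is allowed through when `disallow_batch_sampler` is set.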


@@ -619,7 +619,11 @@ def test_poptorch_models_at_different_stages(tmpdir):
     trainer.optimizers = model.configure_optimizers()[0]
     trainer.state.fn = TrainerFn.FITTING
     trainer.strategy.setup(trainer)
-    assert list(trainer.strategy.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]
+    assert list(trainer.strategy.poptorch_models) == [
+        RunningStage.TRAINING,
+        RunningStage.VALIDATING,
+        RunningStage.SANITY_CHECKING,
+    ]

     for fn, stage in (
         (TrainerFn.VALIDATING, RunningStage.VALIDATING),


@@ -3,12 +3,13 @@ from dataclasses import dataclass
 import pytest
 import torch
 from torch import Tensor
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.data import (
+    _dataloader_init_kwargs_resolve_sampler,
     _get_dataloader_init_args_and_kwargs,
     _replace_dataloader_init_method,
     _update_dataloader,
@@ -331,6 +332,23 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c
             assert getattr(dataloader, key) == value


+def test_dataloader_disallow_batch_sampler():
+    dataset = RandomDataset(5, 100)
+    dataloader = DataLoader(dataset, batch_size=10)
+
+    # This should not raise
+    _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
+
+    dataset = RandomDataset(5, 100)
+    sampler = SequentialSampler(dataset)
+    batch_sampler = BatchSampler(sampler, batch_size=10, drop_last=False)
+    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+
+    # this should raise - using batch sampler, that was not automatically instantiated by DataLoader
+    with pytest.raises(MisconfigurationException, match="when running on multiple IPU devices"):
+        _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
+
+
 @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
 def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
     """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""