Disallow batch sampler with multiple IPU devices (#13854)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
commit 4c7b9f0b11 (parent 56b1e1aaaa)
@@ -62,7 +62,8 @@ Currently there are some known limitations that are being addressed in the near
 
 Please see the `MNIST example <https://github.com/Lightning-AI/lightning/blob/master/examples/pl_ipu/mnist_sample.py>`__ which displays most of the limitations and how to overcome them till they are resolved.
 
-* ``self.log`` is not supported in the ``training_step``, ``validation_step``, ``test_step`` or ``predict_step``. This is due to the step function being traced and sent to the IPU devices. We're actively working on fixing this
-* Multiple optimizers are not supported. ``training_step`` only supports returning one loss from the ``training_step`` function as a result
-* Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code
-* Clipping gradients is not supported
+* ``self.log`` is not supported in the ``training_step``, ``validation_step``, ``test_step`` or ``predict_step``. This is due to the step function being traced and sent to the IPU devices. We're actively working on fixing this.
+* Multiple optimizers are not supported. ``training_step`` only supports returning one loss from the ``training_step`` function as a result.
+* Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code.
+* Clipping gradients is not supported.
+* It is not possible to use :class:`torch.utils.data.BatchSampler` in your dataloaders if you are using multiple IPUs.
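A short illustrative sketch of what the new `BatchSampler` limitation means in user code (dataset and batch size below are made up, not taken from the diff): a loader built with an explicit `BatchSampler` cannot be rebuilt as a `poptorch.DataLoader` when more than one IPU is used, whereas passing `batch_size` keeps the default batch sampler that Lightning knows how to recreate.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.randn(100, 5))

# Rejected when running on multiple IPUs: the explicit BatchSampler cannot be
# carried over into the rebuilt `poptorch.DataLoader`.
bad_loader = DataLoader(
    dataset,
    batch_sampler=BatchSampler(SequentialSampler(dataset), batch_size=10, drop_last=False),
)

# Fine: `batch_size` lets PyTorch create its default BatchSampler internally,
# which Lightning can drop and recreate for the IPU dataloader.
good_loader = DataLoader(dataset, batch_size=10)
```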
@@ -167,7 +167,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Updated Habana Accelerator's `auto_device_count`, `is_available` & `get_device_name` methods based on the latest torch habana package ([#13423](https://github.com/PyTorchLightning/pytorch-lightning/pull/13423))
 
 
--
+- Disallowed using `BatchSampler` when running on multiple IPUs ([#13854](https://github.com/PyTorchLightning/pytorch-lightning/pull/13854))
 
 
 ### Deprecated
@@ -162,6 +162,8 @@ class IPUStrategy(ParallelStrategy):
             if self.lightning_module.trainer.enable_validation:
                 model = poptorch.inferenceModel(model=model, options=inference_opts)
                 self.poptorch_models[RunningStage.VALIDATING] = model
+                if self.lightning_module.trainer.num_sanity_val_steps > 0:
+                    self.poptorch_models[RunningStage.SANITY_CHECKING] = model
         elif trainer_fn == TrainerFn.VALIDATING:
             model = poptorch.inferenceModel(model=model, options=self.inference_opts)
             self.poptorch_models[RunningStage.VALIDATING] = model
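The hunk above registers the already-compiled validation inference model for sanity checking as well. A minimal sketch of that idea (not the strategy's actual code; the stand-in objects and the `num_sanity_val_steps` value are invented for illustration):

```python
from pytorch_lightning.trainer.states import RunningStage

# Stand-in for a compiled `poptorch.inferenceModel`; any object works for the sketch.
validation_model = object()
num_sanity_val_steps = 2  # hypothetical trainer setting

poptorch_models = {RunningStage.VALIDATING: validation_model}
if num_sanity_val_steps > 0:
    # Sanity checking reuses the validation model instead of compiling a second one.
    poptorch_models[RunningStage.SANITY_CHECKING] = validation_model

assert poptorch_models[RunningStage.SANITY_CHECKING] is poptorch_models[RunningStage.VALIDATING]
```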
@@ -228,7 +230,9 @@ class IPUStrategy(ParallelStrategy):
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader
 
-        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
+        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
+            dataloader, sampler, mode, self.replication_factor > 1
+        )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
         return dataloader
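A rough sketch of the call the strategy now makes, assuming the private helper is importable from `pytorch_lightning.utilities.data` and substituting a plain `DataLoader` plus a hard-coded replication factor for the real strategy state:

```python
from torch.utils.data import DataLoader
from pytorch_lightning.demos.boring_classes import RandomDataset
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs

dataloader = DataLoader(RandomDataset(32, 64), batch_size=4)
replication_factor = 2  # hypothetical: more than one IPU

# The default batch sampler passes the new check, and the returned args/kwargs are
# what the strategy feeds into `poptorch.DataLoader(opts, *dl_args, **dl_kwargs)`.
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
    dataloader, dataloader.sampler, RunningStage.TRAINING, replication_factor > 1
)
```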
@@ -186,7 +186,7 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
 def _update_dataloader(
     dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
 ) -> DataLoader:
-    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode)
+    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
     dl_cls = type(dataloader)
     try:
         dataloader = dl_cls(*dl_args, **dl_kwargs)
@@ -212,7 +212,10 @@ def _update_dataloader(
 
 
 def _get_dataloader_init_args_and_kwargs(
-    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
+    dataloader: DataLoader,
+    sampler: Optional[Sampler],
+    mode: Optional[RunningStage] = None,
+    disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
@@ -264,7 +267,7 @@ def _get_dataloader_init_args_and_kwargs(
         dl_kwargs["batch_sampler"] = None
         dl_kwargs["sampler"] = None
     else:
-        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode))
+        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode, disallow_batch_sampler))
 
     required_args = {
         p.name
@@ -309,7 +312,10 @@ def _get_dataloader_init_args_and_kwargs(
 
 
 def _dataloader_init_kwargs_resolve_sampler(
-    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
+    dataloader: DataLoader,
+    sampler: Optional[Sampler],
+    mode: Optional[RunningStage] = None,
+    disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
     re-instantiation.
@@ -321,27 +327,39 @@ def _dataloader_init_kwargs_resolve_sampler(
     fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
     batch_sampler = getattr(dataloader, "batch_sampler")
     is_predicting = mode == RunningStage.PREDICTING
     # checking the batch sampler type is different than PyTorch default.
-    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
-        batch_sampler = type(batch_sampler)(
-            sampler,
-            batch_size=batch_sampler.batch_size,
-            drop_last=(False if is_predicting else batch_sampler.drop_last),
-        )
-        if is_predicting:
-            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
-
-        if fault_tolerant_mode.is_automatic:
-            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
-            fast_forward_sampler.setup(dataloader_batch_size=1)
+    if batch_sampler is not None:
+        if disallow_batch_sampler:
+            # Check that we don't have a PyTorch default batch sampler that was instantiated in DataLoader __init__
+            if not (
+                type(batch_sampler) is BatchSampler
+                and batch_sampler.sampler == sampler
+                and dataloader.batch_size == batch_sampler.batch_size
+            ):
+                raise MisconfigurationException(
+                    "It is not possible to have a batch sampler in your dataloader, "
+                    "when running on multiple IPU devices."
+                )
+        elif type(batch_sampler) is not BatchSampler or is_predicting:
+            batch_sampler = type(batch_sampler)(
+                sampler,
+                batch_size=batch_sampler.batch_size,
+                drop_last=(False if is_predicting else batch_sampler.drop_last),
+            )
+            if is_predicting:
+                batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
 
-        return {
-            "sampler": None,
-            "shuffle": False,
-            "batch_sampler": batch_sampler,
-            "batch_size": 1,
-            "drop_last": False,
-        }
+            if fault_tolerant_mode.is_automatic:
+                fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
+                fast_forward_sampler.setup(dataloader_batch_size=1)
+
+            return {
+                "sampler": None,
+                "shuffle": False,
+                "batch_sampler": batch_sampler,
+                "batch_size": 1,
+                "drop_last": False,
+            }
 
     if fault_tolerant_mode.is_automatic:
         fast_forward_sampler = sampler = FastForwardSampler(sampler)
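The condition guarding the new `MisconfigurationException` can be read as a standalone predicate. A sketch under that reading (a hypothetical helper, not part of Lightning): a batch sampler only counts as the one `DataLoader.__init__` built itself when it is exactly `BatchSampler`, wraps the loader's own sampler, and uses the loader's `batch_size`.

```python
from torch.utils.data import BatchSampler, DataLoader


def is_default_batch_sampler(dataloader: DataLoader) -> bool:
    """Return True if the batch sampler is the one `DataLoader.__init__` created itself."""
    batch_sampler = dataloader.batch_sampler
    return (
        batch_sampler is not None
        and type(batch_sampler) is BatchSampler
        and batch_sampler.sampler is dataloader.sampler
        and dataloader.batch_size == batch_sampler.batch_size
    )


loader = DataLoader(list(range(100)), batch_size=10)
assert is_default_batch_sampler(loader)  # auto-created, so multiple IPUs are fine
```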
@@ -619,7 +619,11 @@ def test_poptorch_models_at_different_stages(tmpdir):
     trainer.optimizers = model.configure_optimizers()[0]
     trainer.state.fn = TrainerFn.FITTING
     trainer.strategy.setup(trainer)
-    assert list(trainer.strategy.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]
+    assert list(trainer.strategy.poptorch_models) == [
+        RunningStage.TRAINING,
+        RunningStage.VALIDATING,
+        RunningStage.SANITY_CHECKING,
+    ]
 
     for fn, stage in (
         (TrainerFn.VALIDATING, RunningStage.VALIDATING),
@@ -3,12 +3,13 @@ from dataclasses import dataclass
 import pytest
 import torch
 from torch import Tensor
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import BatchSampler, DataLoader, SequentialSampler
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.data import (
+    _dataloader_init_kwargs_resolve_sampler,
     _get_dataloader_init_args_and_kwargs,
     _replace_dataloader_init_method,
     _update_dataloader,
@@ -331,6 +332,23 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c
             assert getattr(dataloader, key) == value
 
 
+def test_dataloader_disallow_batch_sampler():
+    dataset = RandomDataset(5, 100)
+    dataloader = DataLoader(dataset, batch_size=10)
+
+    # This should not raise
+    _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
+
+    dataset = RandomDataset(5, 100)
+    sampler = SequentialSampler(dataset)
+    batch_sampler = BatchSampler(sampler, batch_size=10, drop_last=False)
+    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+
+    # this should raise - using batch sampler, that was not automatically instantiated by DataLoader
+    with pytest.raises(MisconfigurationException, match="when running on multiple IPU devices"):
+        _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
+
+
 @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
 def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
     """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""