[IPU] Support manually instantiating the `poptorch.DataLoader` (#12116)

Carlos Mocholí 2022-02-28 10:36:26 +01:00 committed by GitHub
parent b29b07e978
commit 8fd17f2edf
5 changed files with 64 additions and 4 deletions


@@ -113,6 +113,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `Bagua` training strategy ([#11146](https://github.com/PyTorchLightning/pytorch-lightning/pull/11146))
+- Added support for manually returning a `poptorch.DataLoader` in a `*_dataloader` hook ([#12116](https://github.com/PyTorchLightning/pytorch-lightning/pull/12116))
 - Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
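
This changelog entry is the user-facing feature: a `*_dataloader` hook may now return an already-constructed `poptorch.DataLoader`, and Lightning will use it as-is instead of rebuilding it. A minimal usage sketch, assuming an IPU environment with `poptorch` installed; the `ManualLoaderModel` class and its random dataset are illustrative and not part of this commit:

import poptorch
import torch
from torch.utils.data import TensorDataset
import pytorch_lightning as pl

class ManualLoaderModel(pl.LightningModule):
    def train_dataloader(self):
        dataset = TensorDataset(torch.rand(64, 1))
        opts = poptorch.Options()
        # Lightning now keeps this exact loader instead of re-wrapping it
        return poptorch.DataLoader(opts, dataset, batch_size=4, drop_last=True)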


@@ -214,12 +214,13 @@ class IPUStrategy(ParallelStrategy):
     def _convert_to_poptorch_loader(
         self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
     ) -> "poptorch.DataLoader":
-        dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
-        # Override to drop last uneven batch, as IPUs do not support uneven inputs.
-        dl_kwargs["drop_last"] = True
+        if isinstance(dataloader, poptorch.DataLoader):
+            # the user is returning the `poptorch.DataLoader` directly, don't change anything.
+            return dataloader
+
+        dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
-        dataloader = poptorch.DataLoader(**dl_kwargs, options=opts)
+        dataloader = poptorch.DataLoader(opts, **dl_kwargs)
         return dataloader

     def _handle_gradient_accumulation_steps(self) -> None:
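
The change reads as: a loader that is already a `poptorch.DataLoader` is passed through untouched, while anything else is rebuilt around the strategy's options, which `poptorch.DataLoader` takes as its first positional argument. A standalone sketch of that decision under those assumptions; the helper name and the reduced set of forwarded kwargs are illustrative, not the strategy's real signature:

import poptorch
from torch.utils.data import DataLoader

def to_poptorch_loader(dataloader: DataLoader, opts: "poptorch.Options") -> "poptorch.DataLoader":
    if isinstance(dataloader, poptorch.DataLoader):
        # user-instantiated loader: its own options and drop_last are respected
        return dataloader
    # plain DataLoader: rebuild it with the strategy's options
    return poptorch.DataLoader(opts, dataloader.dataset, batch_size=dataloader.batch_size)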


@@ -329,7 +329,13 @@ def _wrap_init(init: Callable) -> Callable:
         params = dict(inspect.signature(obj.__init__).parameters)
         params.pop("args", None)
         params.pop("kwargs", None)
+        cls = type(obj)
         for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
+            if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None:
+                # the class defines a read-only (no setter) property of this name. it's likely that the implementation
+                # will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name`
+                # property so we are fine skipping in that case
+                continue
             setattr(obj, arg_name, arg_value)
         init(obj, *args, **kwargs)
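
The `fset is None` guard matters because assigning to a read-only property raises. A self-contained illustration of the case the wrapper now skips; the `ReadOnly` class is hypothetical:

class ReadOnly:
    def __init__(self, options):
        self._options = options  # backing attribute set in __init__

    @property
    def options(self):  # no setter defined, so ReadOnly.options.fset is None
        return self._options

obj = ReadOnly(123)
assert ReadOnly.options.fset is None
try:
    setattr(obj, "options", 456)  # what the wrapper did before this commit
except AttributeError:
    pass  # "can't set attribute": the wrapper now skips such names instead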


@@ -344,6 +344,38 @@ def test_autoreport(tmpdir):
     assert os.path.isfile(autoreport_path + "profile.pop")


+@RunIf(ipu=True)
+def test_manual_poptorch_dataloader(tmpdir):
+    model_options = poptorch.Options()
+
+    class IPUTestModel(IPUModel):
+        def train_dataloader(self):
+            dataloader = super().train_dataloader()
+            # save to instance to compare the reference later
+            self.poptorch_dataloader = poptorch.DataLoader(model_options, dataloader.dataset, drop_last=True)
+            return self.poptorch_dataloader
+
+    model = IPUTestModel()
+    other_options = poptorch.Options()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        accelerator="ipu",
+        devices=2,
+        strategy=IPUStrategy(training_opts=other_options),
+    )
+
+    trainer.fit(model)
+    assert isinstance(trainer.strategy, IPUStrategy)
+    assert trainer.strategy.training_opts is other_options
+    dataloader = trainer.train_dataloader.loaders
+    assert dataloader is model.poptorch_dataloader  # exact object, was not recreated
+    # dataloader uses the options in the model, not the strategy
+    assert dataloader.options is model_options
+    assert dataloader.options is not other_options
+    assert dataloader.drop_last  # was kept
+
+
 @RunIf(ipu=True)
 def test_manual_poptorch_opts(tmpdir):
     """Ensure if the user passes manual poptorch Options, we run with the correct object."""


@@ -175,6 +175,24 @@ def test_replace_dataloader_init_method():
     assert dataloader.attribute1 == "attribute1"
     assert dataloader.attribute2 == "attribute2"

+    # `poptorch.DataLoader` uses this pattern, simulate it
+    class PoptorchDataLoader(DataLoader):
+        def __init__(self, options, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self._options = options
+
+        @property
+        def options(self):
+            return self._options
+
+    # this read-only property pattern is fine
+    dataloader = PoptorchDataLoader(123, [1])
+    assert dataloader.options == 123
+
+    # still works with the init replacement
+    with _replace_dataloader_init_method():
+        dataloader = PoptorchDataLoader(123, [1])
+    assert dataloader.options == 123


 @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
 def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):