Improve error message on `TypeError` during `DataLoader` reconstruction (#10719)

Carlos Mocholí 2021-11-24 22:51:11 +01:00 committed by GitHub
parent 0066ff0129
commit f8b2d5b128
3 changed files with 53 additions and 2 deletions


@@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
--
+- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
 -


@@ -180,7 +180,25 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
 def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
     dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
     dl_cls = type(dataloader)
-    dataloader = dl_cls(**dl_kwargs)
+    try:
+        dataloader = dl_cls(**dl_kwargs)
+    except TypeError as e:
+        # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
+        # `__init__` arguments map to one `DataLoader.__init__` argument
+        import re
+
+        match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
+        if not match:
+            # an unexpected `TypeError`, continue failure
+            raise
+        argument = match.groups()[0]
+        message = (
+            f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
+            f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
+            f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
+            f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
+        )
+        raise MisconfigurationException(message) from e
     return dataloader
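
For reference, a minimal sketch (not part of this commit) of what the regular expression above extracts from a typical CPython `TypeError` message; the message string below is illustrative, and on newer Python versions it is prefixed with the class name, which the leading `.*` still matches:

import re

# Illustrative message text, as raised when an argument is received both
# positionally and by keyword during `DataLoader.__init__`
error = "__init__() got multiple values for argument 'dataset'"

match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", error)
assert match is not None
print(match.groups()[0])  # -> dataset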


@@ -5,6 +5,7 @@ from torch.utils.data.dataloader import DataLoader
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.data import (
     _replace_dataloader_init_method,
+    _update_dataloader,
     extract_batch_size,
     get_len,
     has_iterable_dataset,
@@ -115,6 +116,38 @@ def test_has_len_all_rank():
     assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)


+def test_update_dataloader_typerror_custom_exception():
+    class BadImpl(DataLoader):
+        def __init__(self, foo, *args, **kwargs):
+            self.foo = foo
+            # positional conflict with `dataset`
+            super().__init__(foo, *args, **kwargs)
+
+    dataloader = BadImpl([1, 2, 3])
+    with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
+        _update_dataloader(dataloader, dataloader.sampler)
+
+    class BadImpl2(DataLoader):
+        def __init__(self, randomize, *args, **kwargs):
+            self.randomize = randomize
+            # keyword conflict with `shuffle`
+            super().__init__(*args, shuffle=randomize, **kwargs)
+
+    dataloader = BadImpl2(False, [])
+    with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"):
+        _update_dataloader(dataloader, dataloader.sampler)
+
+    class GoodImpl(DataLoader):
+        def __init__(self, randomize, *args, **kwargs):
+            # fixed implementation, kwargs are filtered
+            self.randomize = randomize or kwargs.pop("shuffle", False)
+            super().__init__(*args, shuffle=randomize, **kwargs)
+
+    dataloader = GoodImpl(False, [])
+    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
+    assert isinstance(new_dataloader, GoodImpl)
+
+
 def test_replace_dataloader_init_method():
     """Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
     sets them as attributes."""
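
As a rough usage sketch outside the test suite, mirroring the `BadImpl` case above (the subclass name `IndexedDataLoader` is hypothetical), the improved message surfaces like this when Lightning has to rebuild a misbehaving loader:

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _update_dataloader
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IndexedDataLoader(DataLoader):
    # Hypothetical buggy subclass: `indices` is stored and also forwarded
    # positionally, so it collides with the reconstructed `dataset` kwarg.
    def __init__(self, indices, *args, **kwargs):
        self.indices = indices
        super().__init__(indices, *args, **kwargs)


loader = IndexedDataLoader(list(range(8)))
try:
    _update_dataloader(loader, loader.sampler)
except MisconfigurationException as e:
    # the message names the conflicting `dataset` argument and suggests
    # filtering **kwargs before forwarding them to `DataLoader.__init__`
    print(e)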