diff --git a/CHANGELOG.md b/CHANGELOG.md index 59d29e1836..a87b3b95f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) -- +- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) - diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 78d75d9972..9963bf6c85 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -180,7 +180,25 @@ def get_len(dataloader: DataLoader) -> Union[int, float]: def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader: dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode) dl_cls = type(dataloader) - dataloader = dl_cls(**dl_kwargs) + try: + dataloader = dl_cls(**dl_kwargs) + except TypeError as e: + # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass + # `__init__` arguments map to one `DataLoader.__init__` argument + import re + + match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + argument = match.groups()[0] + message = ( + f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument" + f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" + f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." + f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." + ) + raise MisconfigurationException(message) from e return dataloader diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 839b370dbf..e202941cf0 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -5,6 +5,7 @@ from torch.utils.data.dataloader import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.utilities.data import ( _replace_dataloader_init_method, + _update_dataloader, extract_batch_size, get_len, has_iterable_dataset, @@ -115,6 +116,38 @@ def test_has_len_all_rank(): assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model) +def test_update_dataloader_typerror_custom_exception(): + class BadImpl(DataLoader): + def __init__(self, foo, *args, **kwargs): + self.foo = foo + # positional conflict with `dataset` + super().__init__(foo, *args, **kwargs) + + dataloader = BadImpl([1, 2, 3]) + with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): + _update_dataloader(dataloader, dataloader.sampler) + + class BadImpl2(DataLoader): + def __init__(self, randomize, *args, **kwargs): + self.randomize = randomize + # keyword conflict with `shuffle` + super().__init__(*args, shuffle=randomize, **kwargs) + + dataloader = BadImpl2(False, []) + with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"): + _update_dataloader(dataloader, dataloader.sampler) + + class GoodImpl(DataLoader): + def __init__(self, randomize, *args, **kwargs): + # fixed implementation, kwargs are filtered + self.randomize = randomize or kwargs.pop("shuffle", False) + super().__init__(*args, shuffle=randomize, **kwargs) + + dataloader = GoodImpl(False, []) + new_dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert isinstance(new_dataloader, GoodImpl) + + def test_replace_dataloader_init_method(): """Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and sets them as attributes."""