Improve error message on `TypeError` during `DataLoader` reconstruction (#10719)
This commit is contained in:
parent
0066ff0129
commit
f8b2d5b128
|
@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
|
||||
|
||||
|
||||
-
|
||||
- Show a better error message when a custom `DataLoader` subclass is implemented incorrectly and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
|
||||
|
||||
|
||||
-
|
||||
|
|
|
@ -180,7 +180,25 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
|
|||
def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
    """Re-instantiate ``dataloader`` using its own class and freshly resolved init kwargs.

    The keyword arguments are produced by ``_get_dataloader_init_kwargs(dataloader, sampler, mode=mode)``
    and the new object is built with ``type(dataloader)``, so custom subclasses are preserved.

    Args:
        dataloader: The dataloader instance to rebuild.
        sampler: Forwarded to ``_get_dataloader_init_kwargs``.
        mode: Forwarded to ``_get_dataloader_init_kwargs``.

    Returns:
        A new instance of ``type(dataloader)``.

    Raises:
        MisconfigurationException: If the constructor raises a ``TypeError`` whose message reports
            that ``__init__()`` got multiple values for an argument.
        TypeError: Any other ``TypeError`` from the constructor is re-raised unchanged.
    """
    dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
    dl_cls = type(dataloader)
    # Instantiate exactly once, inside the ``try``: an unguarded call before it would raise first
    # and the improved error message below would never be produced.
    try:
        dataloader = dl_cls(**dl_kwargs)
    except TypeError as e:
        # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
        # `__init__` arguments map to one `DataLoader.__init__` argument
        import re

        match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
        if not match:
            # an unexpected `TypeError`, continue failure
            raise
        argument = match.group(1)
        message = (
            f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
            f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
            f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
            f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
        )
        raise MisconfigurationException(message) from e
    return dataloader
|
||||
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from torch.utils.data.dataloader import DataLoader
|
|||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities.data import (
|
||||
_replace_dataloader_init_method,
|
||||
_update_dataloader,
|
||||
extract_batch_size,
|
||||
get_len,
|
||||
has_iterable_dataset,
|
||||
|
@ -115,6 +116,38 @@ def test_has_len_all_rank():
|
|||
assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)
|
||||
|
||||
|
||||
def test_update_dataloader_typerror_custom_exception():
    """Check the error raised by ``_update_dataloader`` for badly implemented ``DataLoader`` subclasses.

    Two subclasses whose ``__init__`` forwards a value that collides with a parent argument
    (``dataset`` positionally, ``shuffle`` by keyword) must produce a ``MisconfigurationException``
    naming the argument. A subclass that filters ``kwargs`` must be re-instantiated successfully.
    """

    class BadImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo
            # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    dataloader = BadImpl([1, 2, 3])
    with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class BadImpl2(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize
            # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = BadImpl2(False, [])
    with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            # Pop unconditionally: `randomize or kwargs.pop(...)` short-circuits when `randomize` is truthy
            # and would leave "shuffle" in `kwargs`, recreating the keyword conflict.
            shuffle = kwargs.pop("shuffle", False)
            self.randomize = randomize or shuffle
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = GoodImpl(False, [])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, GoodImpl)
|
||||
|
||||
|
||||
def test_replace_dataloader_init_method():
|
||||
"""Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
|
||||
sets them as attributes."""
|
||||
|
|
Loading…
Reference in New Issue