diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 5577ed654e..5d54c8e53f 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -321,11 +321,18 @@ def _wrap_init(init: Callable) -> Callable: @functools.wraps(init) def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: - params = dict(inspect.signature(obj.__init__).parameters) - params.pop("args", None) - params.pop("kwargs", None) + # We need to inspect `init`, as inspecting `obj.__init__` + # can lead to inspecting the wrong function with multiple inheritance + params = inspect.signature(init).parameters + + param_names = [ + param.name + for param in params.values() + if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + ] + cls = type(obj) - for arg_name, arg_value in chain(zip(params, args), kwargs.items()): + for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()): if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None: # the class defines a read-only (no setter) property of this name. 
it's likely that the implementation # will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name` diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index f64ba026b0..f056fe99cd 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -161,19 +161,38 @@ def test_replace_dataloader_init_method(): super().__init__(*args, **kwargs) class DataLoaderSubclass2(DataLoaderSubclass1): - def __init__(self, attribute1, attribute2, *args, **kwargs): + def __init__(self, attribute2, *args, **kwargs): # intentionally not setting this attribute, calling super with different args # self.attribute2 = attribute2 - super().__init__(attribute1, *args, **kwargs) + super().__init__(attribute2 + "-2", *args, **kwargs) with _replace_dataloader_init_method(): dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2) - assert dataloader.attribute1 == "attribute1" + + assert dataloader.attribute1 == "attribute1" with _replace_dataloader_init_method(): - dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2) - assert dataloader.attribute1 == "attribute1" - assert dataloader.attribute2 == "attribute2" + dataloader = DataLoaderSubclass2("attribute2", dataset=range(4), batch_size=2) + + assert dataloader.attribute1 == "attribute2-2" + assert dataloader.attribute2 == "attribute2" + + # Failing test case from issue 12564 + class MyBaseDataLoader(DataLoader): + pass + + class MyDataLoader(MyBaseDataLoader): + def __init__(self, data: torch.Tensor, *args, **kwargs): + self.data = data + super().__init__(range(data.size(0)), *args, **kwargs) + + data = torch.randn((10, 20)) + + with _replace_dataloader_init_method(): + dataloader = MyDataLoader(data, batch_size=2) + + assert dataloader.data is data + assert dataloader.dataset == range(10) # `poptorch.DataLoader` uses this pattern, simulate it class PoptorchDataLoader(DataLoader): @@ -188,9 +207,11 @@ def 
test_replace_dataloader_init_method(): # This read-only property pattern is fine dataloader = PoptorchDataLoader(123, [1]) assert dataloader.options == 123 + # still works with the init replacement with _replace_dataloader_init_method(): dataloader = PoptorchDataLoader(123, [1]) + assert dataloader.options == 123