Inspect correct function in wrap_init (#12716)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
otaj 2022-04-15 13:58:28 +02:00 committed by GitHub
parent c0e6fd6a3b
commit b8d4b81221
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 10 deletions

View File

@ -321,11 +321,18 @@ def _wrap_init(init: Callable) -> Callable:
@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
params = dict(inspect.signature(obj.__init__).parameters)
params.pop("args", None)
params.pop("kwargs", None)
# We need to inspect `init`, as inspecting `obj.__init__`
# can lead to inspecting the wrong function with multiple inheritance
params = inspect.signature(init).parameters
param_names = [
param.name
for param in params.values()
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
]
cls = type(obj)
for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()):
if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None:
# the class defines a read-only (no setter) property of this name. it's likely that the implementation
# will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name`

View File

@ -161,19 +161,38 @@ def test_replace_dataloader_init_method():
super().__init__(*args, **kwargs)
class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute1, attribute2, *args, **kwargs):
def __init__(self, attribute2, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute2 = attribute2
super().__init__(attribute1, *args, **kwargs)
super().__init__(attribute2 + "-2", *args, **kwargs)
with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"
assert dataloader.attribute1 == "attribute1"
with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"
assert dataloader.attribute2 == "attribute2"
dataloader = DataLoaderSubclass2("attribute2", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute2-2"
assert dataloader.attribute2 == "attribute2"
# Failing test case from issue 12564
class MyBaseDataLoader(DataLoader):
pass
class MyDataLoader(MyBaseDataLoader):
def __init__(self, data: torch.Tensor, *args, **kwargs):
self.data = data
super().__init__(range(data.size(0)), *args, **kwargs)
data = torch.randn((10, 20))
with _replace_dataloader_init_method():
dataloader = MyDataLoader(data, batch_size=2)
assert dataloader.data is data
assert dataloader.dataset == range(10)
# `poptorch.DataLoader` uses this pattern, simulate it
class PoptorchDataLoader(DataLoader):
@ -188,9 +207,11 @@ def test_replace_dataloader_init_method():
# †his read-only property pattern is fine
dataloader = PoptorchDataLoader(123, [1])
assert dataloader.options == 123
# still works with the init replacement
with _replace_dataloader_init_method():
dataloader = PoptorchDataLoader(123, [1])
assert dataloader.options == 123