Inspect correct function in wrap_init (#12716)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
c0e6fd6a3b
commit
b8d4b81221
|
@ -321,11 +321,18 @@ def _wrap_init(init: Callable) -> Callable:
|
|||
|
||||
@functools.wraps(init)
|
||||
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
|
||||
params = dict(inspect.signature(obj.__init__).parameters)
|
||||
params.pop("args", None)
|
||||
params.pop("kwargs", None)
|
||||
# We need to inspect `init`, as inspecting `obj.__init__`
|
||||
# can lead to inspecting the wrong function with multiple inheritance
|
||||
params = inspect.signature(init).parameters
|
||||
|
||||
param_names = [
|
||||
param.name
|
||||
for param in params.values()
|
||||
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
|
||||
]
|
||||
|
||||
cls = type(obj)
|
||||
for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
|
||||
for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()):
|
||||
if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None:
|
||||
# the class defines a read-only (no setter) property of this name. it's likely that the implementation
|
||||
# will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name`
|
||||
|
|
|
@ -161,19 +161,38 @@ def test_replace_dataloader_init_method():
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
class DataLoaderSubclass2(DataLoaderSubclass1):
|
||||
def __init__(self, attribute1, attribute2, *args, **kwargs):
|
||||
def __init__(self, attribute2, *args, **kwargs):
|
||||
# intentionally not setting this attribute, calling super with different args
|
||||
# self.attribute2 = attribute2
|
||||
super().__init__(attribute1, *args, **kwargs)
|
||||
super().__init__(attribute2 + "-2", *args, **kwargs)
|
||||
|
||||
with _replace_dataloader_init_method():
|
||||
dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2)
|
||||
assert dataloader.attribute1 == "attribute1"
|
||||
|
||||
assert dataloader.attribute1 == "attribute1"
|
||||
|
||||
with _replace_dataloader_init_method():
|
||||
dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2)
|
||||
assert dataloader.attribute1 == "attribute1"
|
||||
assert dataloader.attribute2 == "attribute2"
|
||||
dataloader = DataLoaderSubclass2("attribute2", dataset=range(4), batch_size=2)
|
||||
|
||||
assert dataloader.attribute1 == "attribute2-2"
|
||||
assert dataloader.attribute2 == "attribute2"
|
||||
|
||||
# Failing test case from issue 12564
|
||||
class MyBaseDataLoader(DataLoader):
|
||||
pass
|
||||
|
||||
class MyDataLoader(MyBaseDataLoader):
|
||||
def __init__(self, data: torch.Tensor, *args, **kwargs):
|
||||
self.data = data
|
||||
super().__init__(range(data.size(0)), *args, **kwargs)
|
||||
|
||||
data = torch.randn((10, 20))
|
||||
|
||||
with _replace_dataloader_init_method():
|
||||
dataloader = MyDataLoader(data, batch_size=2)
|
||||
|
||||
assert dataloader.data is data
|
||||
assert dataloader.dataset == range(10)
|
||||
|
||||
# `poptorch.DataLoader` uses this pattern, simulate it
|
||||
class PoptorchDataLoader(DataLoader):
|
||||
|
@ -188,9 +207,11 @@ def test_replace_dataloader_init_method():
|
|||
# †his read-only property pattern is fine
|
||||
dataloader = PoptorchDataLoader(123, [1])
|
||||
assert dataloader.options == 123
|
||||
|
||||
# still works with the init replacement
|
||||
with _replace_dataloader_init_method():
|
||||
dataloader = PoptorchDataLoader(123, [1])
|
||||
|
||||
assert dataloader.options == 123
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue