Pass-through setattr for FabricModule (#17731)

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
Boon 2023-06-13 03:33:51 +08:00 committed by GitHub
parent d75bfe5c38
commit 377bfd2768
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 0 deletions

View File

@ -100,6 +100,7 @@ ignore-init-module-imports = true
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
"S607", # todo: Starting a process with a partial executable path
"RET504", # todo:Unnecessary variable assignment before `return` statement
"RET503",
]
"tests/**" = [
"S101", # Use of `assert` detected

View File

@ -101,6 +101,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
self._forward_module = forward_module
self._original_module = original_module or forward_module
self._precision = precision
self._fabric_module_initialized = True
@property
def module(self) -> nn.Module:
@ -185,6 +186,31 @@ class _FabricModule(_DeviceDtypeModuleMixin):
self._validate_method_access(item, attr)
return attr
def __setattr__(self, name: str, value: Any) -> None:
if not getattr(self, "_fabric_module_initialized", False):
super().__setattr__(name, value)
return
# Get the _original_module attribute
original_module = self._original_module
original_has_attr = hasattr(original_module, name)
# Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
# Can't use self.__getattr__ because it would pass through to the original module
fabric_has_attr = name in self.__dict__
if not (original_has_attr or fabric_has_attr):
setattr(original_module, name, value)
return
# The original module can also inherit from _DeviceDtypeModuleMixin,
# in this case, both the Fabric module and original module have attributes like _dtype
# set attribute on both
if original_has_attr:
setattr(original_module, name, value)
if fabric_has_attr:
super().__setattr__(name, value)
class _FabricDataLoader:
def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:

View File

@ -104,6 +104,56 @@ def test_fabric_module_method_lookup():
warning_cache.clear()
def test_fabric_module_setattr():
"""Test that setattr sets attributes on the original module."""
class OriginalModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(2, 3)
self.attribute = 1
self._x = None
@property
def x(self):
return self._x
@x.setter
def x(self, value):
self._x = value
original_module = OriginalModule()
class ModuleWrapper(torch.nn.Module):
def __init__(self):
super().__init__()
self.wrapped = original_module
wrapped_module = ModuleWrapper()
fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
# Check new attribute is set on original_module
fabric_module.new_attribute = 100
assert original_module.new_attribute == 100
# Modify existing attribute on original_module
fabric_module.attribute = 101
assert original_module.attribute == 101
# Check setattr of original_module
fabric_module.x = 102
assert original_module.x == 102
# Check set submodule
assert not hasattr(original_module, "linear")
linear = torch.nn.Linear(2, 2)
fabric_module.linear = linear
assert hasattr(original_module, "linear")
assert isinstance(original_module.linear, torch.nn.Module)
assert linear in fabric_module.modules()
assert linear in original_module.modules()
def test_fabric_module_state_dict_access():
"""Test that state_dict access passes through to the original module."""