Pass-through setattr for FabricModule (#17731)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
parent
d75bfe5c38
commit
377bfd2768
|
@ -100,6 +100,7 @@ ignore-init-module-imports = true
|
|||
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
|
||||
"S607", # todo: Starting a process with a partial executable path
|
||||
"RET504", # todo:Unnecessary variable assignment before `return` statement
|
||||
"RET503",
|
||||
]
|
||||
"tests/**" = [
|
||||
"S101", # Use of `assert` detected
|
||||
|
|
|
@ -101,6 +101,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
|
|||
self._forward_module = forward_module
|
||||
self._original_module = original_module or forward_module
|
||||
self._precision = precision
|
||||
self._fabric_module_initialized = True
|
||||
|
||||
@property
|
||||
def module(self) -> nn.Module:
|
||||
|
@ -185,6 +186,31 @@ class _FabricModule(_DeviceDtypeModuleMixin):
|
|||
self._validate_method_access(item, attr)
|
||||
return attr
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if not getattr(self, "_fabric_module_initialized", False):
|
||||
super().__setattr__(name, value)
|
||||
return
|
||||
|
||||
# Get the _original_module attribute
|
||||
original_module = self._original_module
|
||||
original_has_attr = hasattr(original_module, name)
|
||||
# Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
|
||||
# Can't use self.__getattr__ because it would pass through to the original module
|
||||
fabric_has_attr = name in self.__dict__
|
||||
|
||||
if not (original_has_attr or fabric_has_attr):
|
||||
setattr(original_module, name, value)
|
||||
return
|
||||
|
||||
# The original module can also inherit from _DeviceDtypeModuleMixin,
|
||||
# in this case, both the Fabric module and original module have attributes like _dtype
|
||||
# set attribute on both
|
||||
if original_has_attr:
|
||||
setattr(original_module, name, value)
|
||||
|
||||
if fabric_has_attr:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
|
||||
class _FabricDataLoader:
|
||||
def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
|
||||
|
|
|
@ -104,6 +104,56 @@ def test_fabric_module_method_lookup():
|
|||
warning_cache.clear()
|
||||
|
||||
|
||||
def test_fabric_module_setattr():
|
||||
"""Test that setattr sets attributes on the original module."""
|
||||
|
||||
class OriginalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.Linear(2, 3)
|
||||
self.attribute = 1
|
||||
self._x = None
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self._x
|
||||
|
||||
@x.setter
|
||||
def x(self, value):
|
||||
self._x = value
|
||||
|
||||
original_module = OriginalModule()
|
||||
|
||||
class ModuleWrapper(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.wrapped = original_module
|
||||
|
||||
wrapped_module = ModuleWrapper()
|
||||
fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
|
||||
|
||||
# Check new attribute is set on original_module
|
||||
fabric_module.new_attribute = 100
|
||||
assert original_module.new_attribute == 100
|
||||
|
||||
# Modify existing attribute on original_module
|
||||
fabric_module.attribute = 101
|
||||
assert original_module.attribute == 101
|
||||
|
||||
# Check setattr of original_module
|
||||
fabric_module.x = 102
|
||||
assert original_module.x == 102
|
||||
|
||||
# Check set submodule
|
||||
assert not hasattr(original_module, "linear")
|
||||
linear = torch.nn.Linear(2, 2)
|
||||
fabric_module.linear = linear
|
||||
assert hasattr(original_module, "linear")
|
||||
assert isinstance(original_module.linear, torch.nn.Module)
|
||||
assert linear in fabric_module.modules()
|
||||
assert linear in original_module.modules()
|
||||
|
||||
|
||||
def test_fabric_module_state_dict_access():
|
||||
"""Test that state_dict access passes through to the original module."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue