diff --git a/pyproject.toml b/pyproject.toml index b9f84c9c7a..9f21631bec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ ignore-init-module-imports = true "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell` "S607", # todo: Starting a process with a partial executable path "RET504", # todo:Unnecessary variable assignment before `return` statement + "RET503", ] "tests/**" = [ "S101", # Use of `assert` detected diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 6246912fef..5375ac9c1a 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -101,6 +101,7 @@ class _FabricModule(_DeviceDtypeModuleMixin): self._forward_module = forward_module self._original_module = original_module or forward_module self._precision = precision + self._fabric_module_initialized = True @property def module(self) -> nn.Module: @@ -185,6 +186,31 @@ class _FabricModule(_DeviceDtypeModuleMixin): self._validate_method_access(item, attr) return attr + def __setattr__(self, name: str, value: Any) -> None: + if not getattr(self, "_fabric_module_initialized", False): + super().__setattr__(name, value) + return + + # Get the _original_module attribute + original_module = self._original_module + original_has_attr = hasattr(original_module, name) + # Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules + # Can't use self.__getattr__ because it would pass through to the original module + fabric_has_attr = name in self.__dict__ + + if not (original_has_attr or fabric_has_attr): + setattr(original_module, name, value) + return + + # The original module can also inherit from _DeviceDtypeModuleMixin, + # in this case, both the Fabric module and original module have attributes like _dtype + # set attribute on both + if original_has_attr: + setattr(original_module, name, value) + + if fabric_has_attr: + super().__setattr__(name, value) + class _FabricDataLoader: def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None: diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 9db28585a5..ec8bf01bee 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -104,6 +104,56 @@ def test_fabric_module_method_lookup(): warning_cache.clear() +def test_fabric_module_setattr(): + """Test that setattr sets attributes on the original module.""" + + class OriginalModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(2, 3) + self.attribute = 1 + self._x = None + + @property + def x(self): + return self._x + + @x.setter + def x(self, value): + self._x = value + + original_module = OriginalModule() + + class ModuleWrapper(torch.nn.Module): + def __init__(self): + super().__init__() + self.wrapped = original_module + + wrapped_module = ModuleWrapper() + fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module) + + # Check new attribute is set on original_module + fabric_module.new_attribute = 100 + assert original_module.new_attribute == 100 + + # Modify existing attribute on original_module + fabric_module.attribute = 101 + assert original_module.attribute == 101 + + # Check setattr of original_module + fabric_module.x = 102 + assert original_module.x == 102 + + # Check set submodule + assert not hasattr(original_module, "linear") + linear = torch.nn.Linear(2, 2) + fabric_module.linear = linear + assert hasattr(original_module, "linear") + assert isinstance(original_module.linear, torch.nn.Module) + assert linear in fabric_module.modules() + assert linear in original_module.modules() + + def test_fabric_module_state_dict_access(): """Test that state_dict access passes through to the original module."""