diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 58af69db93..7fbc5d25de 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -14,7 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Calling a method other than `forward` that invokes submodules is now an error when the model is wrapped (e.g., with DDP) ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))
+
 
 
 ### Deprecated
@@ -29,7 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed false-positive warnings about method calls on the Fabric-wrapped module ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))
 
 
 ## [2.1.0] - 2023-10-11
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index 1a40f44730..6dda42e987 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from functools import wraps
 from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, TypeVar, Union, overload
 
 import torch
-from lightning_utilities import WarningCache
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch import nn as nn
@@ -30,9 +30,7 @@ from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.types import Optimizable
-from lightning.fabric.utilities.warnings import PossibleUserWarning
 
-warning_cache = WarningCache()
 T_destination = TypeVar("T_destination", bound=Dict[str, Any])
 
 _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
@@ -161,25 +159,40 @@ class _FabricModule(_DeviceDtypeModuleMixin):
         # We expect that the `forward_module` will eventually call `original_module.forward`, which we
         # have patched to redirect back to `original_module.method_name()`.
         def call_forward_module(*args: Any, **kwargs: Any) -> Any:
-            # Patch the original_module's forward so we can redirect the arguments back to the real method
+            # Patch the original_module's forward, so we can redirect the arguments back to the real method
             self._original_module.forward = wrapped_forward
             return self.forward(*args, **kwargs)
 
         return call_forward_module
 
-    def _validate_method_access(self, name: str, attribute: Any) -> None:
-        if (
-            inspect.ismethod(attribute)
-            and inspect.signature(attribute).parameters
-            and self._forward_module != self._original_module
-        ):
-            warning_cache.warn(
-                f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
-                " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
-                " `.backward()`. You should pass your inputs through"
-                f" `{type(self._original_module).__name__}.forward()`.",
-                category=PossibleUserWarning,
-            )
+    def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
+        """Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
+        registering forward hooks on all submodules."""
+        module_called = False
+
+        def hook(*_: Any, **__: Any) -> None:
+            nonlocal module_called
+            module_called = True
+
+        @wraps(method)
+        def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
+            handles = []
+            for module in self._original_module.modules():
+                handles.append(module.register_forward_hook(hook))
+
+            output = method(*args, **kwargs)
+
+            if module_called:
+                raise RuntimeError(
+                    f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
+                    " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
+                    " `.backward()`. You should pass your inputs through `forward()`.",
+                )
+            for handle in handles:
+                handle.remove()
+            return output
+
+        return _wrapped_method
 
     def __getattr__(self, item: Any) -> Any:
         if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
@@ -194,7 +207,9 @@ class _FabricModule(_DeviceDtypeModuleMixin):
             # If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
             original_module = super().__getattr__("_original_module")
             attr = getattr(original_module, item)
-            self._validate_method_access(item, attr)
+
+            if inspect.ismethod(attr) and self._forward_module != self._original_module:
+                attr = self._wrap_method_with_module_call_tracker(attr, item)
             return attr
 
     def __setattr__(self, name: str, value: Any) -> None:
diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index dbfaf0f907..820d9032f5 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -26,9 +26,7 @@ from lightning.fabric.wrappers import (
     _FabricOptimizer,
     _unwrap_objects,
     is_wrapped,
-    warning_cache,
 )
-from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import BatchSampler, DistributedSampler
 from torch.utils.data.dataloader import DataLoader
 
@@ -79,12 +77,24 @@ def test_fabric_module_method_lookup():
     """Test that access to methods warns about improper use when a wrapper from a strategy is involved."""
 
     class OriginalModule(torch.nn.Module):
-        def method_no_args(self):
+        def __init__(self):
+            super().__init__()
+            self.submodule = torch.nn.Linear(2, 3)
+
+        def forward(self, x):
+            return x
+
+        def method_without_module_invocation(self):
             return 100
 
-        def method_with_args(self, arg, kwarg=1):
+        def method_with_submodule_invocation(self):
+            self.submodule(torch.rand(2, 2))
             return 101
 
+        def method_with_self_invocation(self):
+            self(None)
+            return 102
+
     class ModuleWrapper(torch.nn.Module):
         def __init__(self, module):
             super().__init__()
@@ -93,21 +103,21 @@ def test_fabric_module_method_lookup():
     # Regular case: forward_module == original_module -> no warnings
     original_module = OriginalModule()
     fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
-    warning_cache.clear()
-    with no_warning_call(UserWarning):
-        assert fabric_module.method_with_args(0) == 101
-    assert not warning_cache
+    assert fabric_module.method_without_module_invocation() == 100
 
     # Special case: original module wrapped by forward module: -> warn if method accepts args
     original_module = OriginalModule()
     wrapped_module = ModuleWrapper(original_module)
     fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
-    warning_cache.clear()
-    with no_warning_call(UserWarning):
-        assert fabric_module.method_no_args() == 100
-    with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method_with_args\(\)` from"):
-        assert fabric_module.method_with_args(0) == 101
-    warning_cache.clear()
+    assert fabric_module.method_without_module_invocation() == 100
+    with pytest.raises(
+        RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
+    ):
+        assert fabric_module.method_with_submodule_invocation() == 101
+    with pytest.raises(
+        RuntimeError, match=r"You are calling the method `OriginalModule.method_with_self_invocation\(\)` from"
+    ):
+        assert fabric_module.method_with_self_invocation() == 102
 
 
 def test_fabric_module_setattr():
@@ -555,7 +565,7 @@ def test_step_method_redirection():
     fabric_module = _FabricModule(forward_module=forward_module, precision=precision, original_module=original_module)
 
     # Regular methods on the original_module are visible and identical on the fabric_module ...
-    assert fabric_module.normal_method == original_module.normal_method
+    assert fabric_module.normal_method.__wrapped__ == original_module.normal_method
 
     # ... but special methods like training_step get redirected to the forward_module
     assert fabric_module.training_step.__name__ == "call_forward_module"
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index cb5b2f2740..113728167c 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1488,7 +1488,7 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
     callback = ModelCheckpoint(dirpath=first, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
     trainer = Trainer(callbacks=callback, max_steps=5, **trainer_kwargs)
     trainer.fit(model)
-    assert os.listdir(first) == ["epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"]
+    assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"}
 
     # Continue training from checkpoint
     callback = ModelCheckpoint(dirpath=new_dirpath, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
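A minimal sketch (not part of the diff) of the behavior change introduced above: `_FabricModule` is constructed directly, as the updated test does, to stand in for a strategy wrapper such as DDP without launching a distributed job. The `MyModel`, `StrategyWrapper`, and `generate` names are made up for illustration; in user code the wrapped module comes from `fabric.setup(model)`.

```python
from unittest.mock import Mock

import torch
from lightning.fabric.wrappers import _FabricModule


class MyModel(torch.nn.Module):  # hypothetical model, mirrors OriginalModule in the test
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.layer(x)

    def generate(self, x):
        # Custom method that invokes a submodule directly
        return self.layer(x)


class StrategyWrapper(torch.nn.Module):  # stand-in for e.g. DistributedDataParallel
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


original = MyModel()
fabric_module = _FabricModule(
    forward_module=StrategyWrapper(original), precision=Mock(), original_module=original
)

# Methods that never touch a submodule can still be called from the outside.
# Methods that do (like `generate`) previously triggered at most a PossibleUserWarning;
# with this PR they raise, because the submodule call bypasses the strategy wrapper.
try:
    fabric_module.generate(torch.rand(1, 2))
except RuntimeError as err:
    print(err)  # "You are calling the method `MyModel.generate()` from outside the model. ..."
```

Routing inputs through `forward()` (i.e., calling the Fabric-wrapped module directly) keeps the strategy wrapper in the loop and remains the supported path.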