Avoid false-positive warnings about method calls on the Fabric-wrapped module (#18819)
parent e7afe04ee8
commit 97303b0168
@@ -14,7 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

 -
+- Calling a method other than `forward` that invokes submodules is now an error when the model is wrapped (e.g., with DDP) ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))


 ### Deprecated

@@ -29,7 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed false-positive warnings about method calls on the Fabric-wrapped module ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))


 ## [2.1.0] - 2023-10-11
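
In user-facing terms, the "Changed" entry means that calling a custom method that runs submodules on a strategy-wrapped model now fails fast instead of silently bypassing the wrapper in `.backward()`. A minimal sketch of the new behavior, assuming a standard Fabric setup — `MyModel` and its `generate` method are invented for illustration, and single-process CPU DDP is used here only so that Fabric actually wraps the module:

```python
import torch
from lightning.fabric import Fabric

class MyModel(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.layer(x)

    def generate(self, x):
        # Custom method that invokes a submodule directly
        return self.layer(x)

fabric = Fabric(accelerator="cpu", strategy="ddp", devices=1)
fabric.launch()
model = fabric.setup(MyModel())

model(torch.rand(1, 2))            # OK: routed through the DDP wrapper's forward
model.generate(torch.rand(1, 2))   # RuntimeError: submodule call bypasses the wrapper
```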
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from functools import wraps
 from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, TypeVar, Union, overload

 import torch
-from lightning_utilities import WarningCache
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch import nn as nn

@@ -30,9 +30,7 @@ from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.types import Optimizable
-from lightning.fabric.utilities.warnings import PossibleUserWarning

-warning_cache = WarningCache()
 T_destination = TypeVar("T_destination", bound=Dict[str, Any])
 _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
@@ -161,25 +159,40 @@ class _FabricModule(_DeviceDtypeModuleMixin):
         # We expect that the `forward_module` will eventually call `original_module.forward`, which we
         # have patched to redirect back to `original_module.method_name()`.
         def call_forward_module(*args: Any, **kwargs: Any) -> Any:
-            # Patch the original_module's forward so we can redirect the arguments back to the real method
+            # Patch the original_module's forward, so we can redirect the arguments back to the real method
             self._original_module.forward = wrapped_forward
             return self.forward(*args, **kwargs)

         return call_forward_module

-    def _validate_method_access(self, name: str, attribute: Any) -> None:
-        if (
-            inspect.ismethod(attribute)
-            and inspect.signature(attribute).parameters
-            and self._forward_module != self._original_module
-        ):
-            warning_cache.warn(
-                f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
-                " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
-                " `.backward()`. You should pass your inputs through"
-                f" `{type(self._original_module).__name__}.forward()`.",
-                category=PossibleUserWarning,
-            )
+    def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
+        """Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
+        registering forward hooks on all submodules."""
+        module_called = False
+
+        def hook(*_: Any, **__: Any) -> None:
+            nonlocal module_called
+            module_called = True
+
+        @wraps(method)
+        def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
+            handles = []
+            for module in self._original_module.modules():
+                handles.append(module.register_forward_hook(hook))
+
+            output = method(*args, **kwargs)
+
+            if module_called:
+                raise RuntimeError(
+                    f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
+                    " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
+                    " `.backward()`. You should pass your inputs through `forward()`.",
+                )
+            for handle in handles:
+                handle.remove()
+            return output
+
+        return _wrapped_method

     def __getattr__(self, item: Any) -> Any:
         if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:

@@ -194,7 +207,9 @@ class _FabricModule(_DeviceDtypeModuleMixin):
         # If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
         original_module = super().__getattr__("_original_module")
         attr = getattr(original_module, item)
-        self._validate_method_access(item, attr)
+
+        if inspect.ismethod(attr) and self._forward_module != self._original_module:
+            attr = self._wrap_method_with_module_call_tracker(attr, item)
         return attr

     def __setattr__(self, name: str, value: Any) -> None:
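
The new `_wrap_method_with_module_call_tracker` is the heart of the fix: instead of warning whenever a method merely accepts arguments, it registers a temporary forward hook on every submodule and only raises if one of those hooks actually fires during the call. The same idea in a standalone sketch, using plain PyTorch and illustrative names (not the Fabric internals):

```python
import torch

def ran_any_submodule(root: torch.nn.Module, method, *args):
    """Run `method` and report whether any submodule of `root` executed its forward()."""
    called = False

    def hook(*_):  # forward hooks receive (module, input, output); we ignore them
        nonlocal called
        called = True

    # Temporarily hook every submodule (root.modules() includes root itself)
    handles = [m.register_forward_hook(hook) for m in root.modules()]
    try:
        output = method(*args)
    finally:
        for handle in handles:
            handle.remove()
    return output, called

net = torch.nn.Sequential(torch.nn.Linear(2, 2))
_, called = ran_any_submodule(net, net, torch.rand(1, 2))
assert called  # calling the model runs its submodules, so the hooks fired
_, called = ran_any_submodule(net, lambda x: x, torch.rand(1, 2))
assert not called  # a method that never touches a submodule trips nothing
```

The remaining hunks update the tests accordingly.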
@@ -26,9 +26,7 @@ from lightning.fabric.wrappers import (
     _FabricOptimizer,
     _unwrap_objects,
     is_wrapped,
-    warning_cache,
 )
-from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import BatchSampler, DistributedSampler
 from torch.utils.data.dataloader import DataLoader
@@ -79,12 +77,24 @@ def test_fabric_module_method_lookup():
     """Test that access to methods warns about improper use when a wrapper from a strategy is involved."""

     class OriginalModule(torch.nn.Module):
-        def method_no_args(self):
+        def __init__(self):
+            super().__init__()
+            self.submodule = torch.nn.Linear(2, 3)
+
+        def forward(self, x):
+            return x
+
+        def method_without_module_invocation(self):
             return 100

-        def method_with_args(self, arg, kwarg=1):
+        def method_with_submodule_invocation(self):
+            self.submodule(torch.rand(2, 2))
             return 101

+        def method_with_self_invocation(self):
+            self(None)
+            return 102
+
     class ModuleWrapper(torch.nn.Module):
         def __init__(self, module):
             super().__init__()

@@ -93,21 +103,21 @@ def test_fabric_module_method_lookup():
     # Regular case: forward_module == original_module -> no warnings
     original_module = OriginalModule()
     fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
-    warning_cache.clear()
-    with no_warning_call(UserWarning):
-        assert fabric_module.method_with_args(0) == 101
-    assert not warning_cache
+    assert fabric_module.method_without_module_invocation() == 100

     # Special case: original module wrapped by forward module: -> warn if method accepts args
     original_module = OriginalModule()
     wrapped_module = ModuleWrapper(original_module)
     fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
-    warning_cache.clear()
-    with no_warning_call(UserWarning):
-        assert fabric_module.method_no_args() == 100
-    with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method_with_args\(\)` from"):
-        assert fabric_module.method_with_args(0) == 101
-    warning_cache.clear()
+    assert fabric_module.method_without_module_invocation() == 100
+    with pytest.raises(
+        RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
+    ):
+        assert fabric_module.method_with_submodule_invocation() == 101
+    with pytest.raises(
+        RuntimeError, match=r"You are calling the method `OriginalModule.method_with_self_invocation\(\)` from"
+    ):
+        assert fabric_module.method_with_self_invocation() == 102


 def test_fabric_module_setattr():
@@ -555,7 +565,7 @@ def test_step_method_redirection():
     fabric_module = _FabricModule(forward_module=forward_module, precision=precision, original_module=original_module)

     # Regular methods on the original_module are visible and identical on the fabric_module ...
-    assert fabric_module.normal_method == original_module.normal_method
+    assert fabric_module.normal_method.__wrapped__ == original_module.normal_method

     # ... but special methods like training_step get redirected to the forward_module
     assert fabric_module.training_step.__name__ == "call_forward_module"
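
The changed assertion follows from the `@wraps(method)` decorator above: attribute lookups on the Fabric wrapper now return a tracking wrapper rather than the bound method itself, and `functools.wraps` exposes the original under `__wrapped__`. A small generic illustration — the decorator here is invented, not the Fabric internals:

```python
from functools import wraps

def passthrough(fn):  # stand-in for the module-call tracker
    @wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

def normal_method():
    return 42

wrapped = passthrough(normal_method)
assert wrapped is not normal_method          # an identity comparison would now fail
assert wrapped.__wrapped__ is normal_method  # @wraps records the original callable
assert wrapped.__name__ == "normal_method"   # metadata is copied over too
```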
@@ -1488,7 +1488,7 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
     callback = ModelCheckpoint(dirpath=first, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
     trainer = Trainer(callbacks=callback, max_steps=5, **trainer_kwargs)
     trainer.fit(model)
-    assert os.listdir(first) == ["epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"]
+    assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"}

     # Continue training from checkpoint
     callback = ModelCheckpoint(dirpath=new_dirpath, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
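
This last hunk touches an unrelated checkpoint test; the motivation is presumably ordering: `os.listdir` makes no guarantee about the order of directory entries, so an order-sensitive list comparison is flaky across filesystems while a set comparison is deterministic. For example:

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as d:
    for name in ("epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"):
        open(os.path.join(d, name), "w").close()
    # os.listdir() order is filesystem-dependent, so this may or may not hold:
    #     os.listdir(d) == ["epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"]
    # Comparing as sets is order-independent and therefore stable:
    assert set(os.listdir(d)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"}
```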