diff --git a/CHANGELOG.md b/CHANGELOG.md index 03aede03cb..3cd7495990 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014)) +- Fixed `materialize_module` setting a module's child recursively ([#12870](https://github.com/PyTorchLightning/pytorch-lightning/pull/12870)) + + - Fixed the number of references of `LightningModule` so it can be deleted ([#12897](https://github.com/PyTorchLightning/pytorch-lightning/pull/12897)) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index d14f111e87..a5edcfb300 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -186,7 +186,7 @@ def materialize_module(root_module: nn.Module) -> nn.Module: if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)): materialize_module(child) else: - setattr(child, name, materialize_fn()) + setattr(root_module, name, materialize_fn()) return root_module diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index c55f04591f..34a2c3d04e 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest from torch import nn from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module +from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -24,7 +26,7 @@ class MLP(nn.Module): self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)] + [nn.Dropout(), nn.LayerNorm(1)]) -class BoringModel(LightningModule): +class SimpleBoringModel(LightningModule): def __init__(self, num_layers: int): super().__init__() self.save_hyperparameters() @@ -48,7 +50,7 @@ def test_init_meta_context(): assert not is_on_meta_device(mlp) assert not is_on_meta_device(nn.Module()) - model = BoringModel(4) + model = SimpleBoringModel(4) assert model.layer[0].weight.device.type == "meta" materialize_module(model) assert model.layer[0].weight.device.type == "cpu" @@ -68,3 +70,15 @@ def test_init_meta_context(): m = nn.Linear(in_features=1, out_features=1) assert m.weight.device.type == "cpu" + + +@RunIf(min_torch="1.10.0", standalone=True) +def test_materialize_module_recursive_child(): + """Test materialize_module doesn't set a child recursively to a model instantiated within init_meta_context.""" + with init_meta_context(): + model = BoringModel() + + materialize_module(model) + + with pytest.raises(AttributeError, match="'Linear' object has no attribute 'layer'"): + model.layer.layer