Fix `materialize_module` recursively setting its child module (#12870)

* Don't set materialized child to child's child
* Update CHANGELOG

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Akihiro Nitta 2022-05-14 03:05:38 +09:00 committed by GitHub
parent db7b0361a5
commit 03039a236e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 3 deletions

View File

@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
- Fixed `materialize_module` setting a module's child recursively ([#12870](https://github.com/PyTorchLightning/pytorch-lightning/pull/12870))
- Fixed the number of references of `LightningModule` so it can be deleted ([#12897](https://github.com/PyTorchLightning/pytorch-lightning/pull/12897))

View File

@ -186,7 +186,7 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
materialize_module(child)
else:
setattr(child, name, materialize_fn())
setattr(root_module, name, materialize_fn())
return root_module

View File

@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@ -24,7 +26,7 @@ class MLP(nn.Module):
self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)] + [nn.Dropout(), nn.LayerNorm(1)])
class BoringModel(LightningModule):
class SimpleBoringModel(LightningModule):
def __init__(self, num_layers: int):
super().__init__()
self.save_hyperparameters()
@ -48,7 +50,7 @@ def test_init_meta_context():
assert not is_on_meta_device(mlp)
assert not is_on_meta_device(nn.Module())
model = BoringModel(4)
model = SimpleBoringModel(4)
assert model.layer[0].weight.device.type == "meta"
materialize_module(model)
assert model.layer[0].weight.device.type == "cpu"
@ -68,3 +70,15 @@ def test_init_meta_context():
m = nn.Linear(in_features=1, out_features=1)
assert m.weight.device.type == "cpu"
@RunIf(min_torch="1.10.0", standalone=True)
def test_materialize_module_recursive_child():
"""Test materialize_module doesn't set a child recursively to a model instantiated within init_meta_context."""
with init_meta_context():
model = BoringModel()
materialize_module(model)
with pytest.raises(AttributeError, match="'Linear' object has no attribute 'layer'"):
model.layer.layer