Fix `materialize_module` recursively setting its child module (#12870)
* Don't set materialized child to child's child * Update CHANGELOG Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
db7b0361a5
commit
03039a236e
|
@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
|
||||
|
||||
|
||||
- Fixed `materialize_module` setting a module's child recursively ([#12870](https://github.com/PyTorchLightning/pytorch-lightning/pull/12870))
|
||||
|
||||
|
||||
- Fixed the number of references of `LightningModule` so it can be deleted ([#12897](https://github.com/PyTorchLightning/pytorch-lightning/pull/12897))
|
||||
|
||||
|
||||
|
|
|
@ -186,7 +186,7 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
|
|||
if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
|
||||
materialize_module(child)
|
||||
else:
|
||||
setattr(child, name, materialize_fn())
|
||||
setattr(root_module, name, materialize_fn())
|
||||
return root_module
|
||||
|
||||
|
||||
|
|
|
@ -11,10 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
from torch import nn
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
|
||||
|
@ -24,7 +26,7 @@ class MLP(nn.Module):
|
|||
self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)] + [nn.Dropout(), nn.LayerNorm(1)])
|
||||
|
||||
|
||||
class BoringModel(LightningModule):
|
||||
class SimpleBoringModel(LightningModule):
|
||||
def __init__(self, num_layers: int):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
@ -48,7 +50,7 @@ def test_init_meta_context():
|
|||
assert not is_on_meta_device(mlp)
|
||||
assert not is_on_meta_device(nn.Module())
|
||||
|
||||
model = BoringModel(4)
|
||||
model = SimpleBoringModel(4)
|
||||
assert model.layer[0].weight.device.type == "meta"
|
||||
materialize_module(model)
|
||||
assert model.layer[0].weight.device.type == "cpu"
|
||||
|
@ -68,3 +70,15 @@ def test_init_meta_context():
|
|||
|
||||
m = nn.Linear(in_features=1, out_features=1)
|
||||
assert m.weight.device.type == "cpu"
|
||||
|
||||
|
||||
@RunIf(min_torch="1.10.0", standalone=True)
|
||||
def test_materialize_module_recursive_child():
|
||||
"""Test materialize_module doesn't set a child recursively to a model instantiated within init_meta_context."""
|
||||
with init_meta_context():
|
||||
model = BoringModel()
|
||||
|
||||
materialize_module(model)
|
||||
|
||||
with pytest.raises(AttributeError, match="'Linear' object has no attribute 'layer'"):
|
||||
model.layer.layer
|
||||
|
|
Loading…
Reference in New Issue