Resolve instantiation problem with init_meta_context (#10493)

This commit is contained in:
thomas chaton 2021-11-15 19:13:01 +00:00 committed by GitHub
parent ae71284627
commit 1de3539eac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 21 deletions

View File

@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
- Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))
- Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))

View File

@ -17,6 +17,8 @@ from typing import Any, Optional, Union
import torch
from torch.nn import Module
import pytorch_lightning as pl
class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ["device", "dtype"]
@ -177,7 +179,9 @@ class DeviceDtypeModuleMixin(Module):
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
# TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
# work when using `init_meta_context`.
if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
return
if device is not None:
module._device = device

View File

@ -84,7 +84,7 @@ from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import materialize_module
from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import (
@ -1406,10 +1406,21 @@ class Trainer(
def _call_configure_sharded_model(self) -> None:
with self.accelerator.model_sharded_context():
materialize_module(self.lightning_module)
self._handle_meta_model()
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")
def _handle_meta_model(self) -> None:
if not is_on_meta_device(self.lightning_module):
return
if isinstance(self.training_type_plugin, DDPSpawnPlugin):
raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")
materialize_module(self.lightning_module)
# the trainer reference is lost during materialization
self.lightning_module.trainer = proxy(self)
def _call_teardown_hook(self) -> None:
fn = self.state.fn._setup_fn

View File

@ -18,13 +18,14 @@ from contextlib import contextmanager
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type
import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}
__CREATED_MODULES__ = set()
@ -237,45 +237,52 @@ def _set_meta_device() -> None:
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
continue
# if a subclass has already been stored, we should use the cache
if str(subclass) in __STORAGE_META__:
# reset the class import package to its rightfull state.
# reset the class import package to its rightful state.
mods, subclass, meta_class = __STORAGE_META__[subclass]
for mod in mods:
setattr(mod, subclass.__name__, meta_class)
continue
class _IsinstanceMetaclass(type(subclass)):
def __instancecheck__(self, instance: Any) -> bool:
"""Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
return isinstance(instance, self.__bases__[0])
# Create a class subclassing current `subclass` overriding its new method.
# this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
# version of the current subclass module
class _MetaClass(subclass):
class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
@classmethod
@contextmanager
def instantiation_context(cls, materialize: bool):
def instantiation_context(cls):
_unset_meta_device(from_created=True)
yield
_set_meta_device_populated(from_created=True)
@classmethod
def materialize(cls, materialize_fn: Callable):
with cls.instantiation_context(materialize=True):
with cls.instantiation_context():
obj = materialize_fn()
return obj
@staticmethod
def add_subclasses(subclass):
"""This is used to unrol the instantion tree while creating the modules."""
__CREATED_MODULES__.add(subclass)
"""This is used to unroll the instantiation tree while creating the modules."""
# Don't store the LightningModule as skipped from the Meta process.
if subclass != pl.LightningModule:
__CREATED_MODULES__.add(subclass)
if subclass.__bases__[0] != torch.nn.modules.module.Module:
_MetaClass.add_subclasses(subclass.__bases__[0])
_MaterializerModule.add_subclasses(subclass.__bases__[0])
def __new__(cls, *args, **kwargs):
subclass = cls.__bases__[0]
cls.add_subclasses(subclass)
with cls.instantiation_context(materialize=False):
with cls.instantiation_context():
obj = init_meta(subclass, *args, **kwargs)
obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
@ -294,9 +301,8 @@ def _set_meta_device() -> None:
# nn.Module class can be imported at different level and they all need to be mocked.
# Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
# Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
# needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
out = []
out.append(search(mod))
# needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule
out = [search(mod)]
for name in submodules[1:]:
mod = getattr(mod, name)
out.append(search(mod))
@ -305,11 +311,11 @@ def _set_meta_device() -> None:
mods = [mod for mod in chain(*out) if mod]
# store the modules search so it doesn't have to be performed again for this class
__STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
__STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)
# replace all subclass by its meta form
for mod in mods:
setattr(mod, subclass.__name__, _MetaClass)
setattr(mod, subclass.__name__, _MaterializerModule)
@contextmanager
@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
_set_meta_device()
yield
_unset_meta_device()
def is_on_meta_device(module: nn.Module) -> bool:
try:
param = next(module.parameters())
return param.device.type == "meta"
except StopIteration:
return False

View File

@ -14,7 +14,7 @@
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
from tests.helpers.runif import RunIf
@ -31,18 +31,23 @@ class BoringModel(LightningModule):
self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)])
@RunIf(min_torch="1.10.0")
@RunIf(special=True, min_torch="1.10.0")
def test_init_meta_context():
with init_meta_context():
m = nn.Linear(in_features=1, out_features=1)
assert isinstance(m, nn.Linear)
assert m.weight.device.type == "meta"
assert is_on_meta_device(m)
mlp = MLP(4)
assert mlp.layer[0].weight.device.type == "meta"
mlp = materialize_module(mlp)
assert mlp.layer[0].weight.device.type == "cpu"
assert not is_on_meta_device(mlp)
assert not is_on_meta_device(nn.Module())
model = BoringModel(4)
assert model.layer[0].weight.device.type == "meta"
materialize_module(model)