Resolve instantiation problem with init_meta_context (#10493)
This commit is contained in:
parent ae71284627
commit 1de3539eac
CHANGELOG.md
@@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
 
+- Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/pytorch-lightning/pull/10493))
+
+
 - Fixed an issue that prevented the Trainer from shutting down workers when execution is interrupted due to a failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
 
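A minimal sketch of the behavior the #10493 entry describes, using only APIs that appear in this commit (before the fix, the `isinstance` check failed when the module class had been swapped for its meta-device stand-in):

    import torch.nn as nn
    from pytorch_lightning.utilities.meta import init_meta_context, materialize_module

    with init_meta_context():
        layer = nn.Linear(1, 1)
        assert isinstance(layer, nn.Linear)        # failed before this fix
        assert layer.weight.device.type == "meta"  # no real storage allocated yet
        layer = materialize_module(layer)          # materializes real weights
        assert layer.weight.device.type == "cpu"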
pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -17,6 +17,8 @@ from typing import Any, Optional, Union
 import torch
 from torch.nn import Module
 
+import pytorch_lightning as pl
+
 
 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ["device", "dtype"]
@@ -177,7 +179,9 @@ class DeviceDtypeModuleMixin(Module):
         self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
     ) -> None:
         def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
-            if not isinstance(module, DeviceDtypeModuleMixin):
+            # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
+            # work when using `init_meta_context`.
+            if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
                 return
             if device is not None:
                 module._device = device
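The hunk only shows `apply_fn`; the enclosing method (not visible here) walks the module tree with `Module.apply` so every submodule's cached `_device` gets refreshed. A self-contained sketch of that traversal pattern, with `_update_device` as a hypothetical stand-in for the real method:

    import torch
    from torch.nn import Module

    def _update_device(root: Module, device: torch.device) -> None:
        # Module.apply visits `root` and every submodule, letting apply_fn
        # stamp the new device on each one (mirroring the hunk above).
        def apply_fn(module: Module) -> None:
            module._device = device

        root.apply(apply_fn)

    net = torch.nn.Sequential(torch.nn.Linear(2, 2))
    _update_device(net, torch.device("cpu"))
    assert net[0]._device == torch.device("cpu")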
pytorch_lightning/trainer/trainer.py
@@ -84,7 +84,7 @@ from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.meta import materialize_module
+from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import (
@@ -1406,10 +1406,21 @@ class Trainer(
 
     def _call_configure_sharded_model(self) -> None:
         with self.accelerator.model_sharded_context():
-            materialize_module(self.lightning_module)
+            self._handle_meta_model()
             self.call_hook("configure_sharded_model")
             self.call_hook("on_configure_sharded_model")
 
+    def _handle_meta_model(self) -> None:
+        if not is_on_meta_device(self.lightning_module):
+            return
+
+        if isinstance(self.training_type_plugin, DDPSpawnPlugin):
+            raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")
+
+        materialize_module(self.lightning_module)
+        # the trainer reference is lost during materialization
+        self.lightning_module.trainer = proxy(self)
+
     def _call_teardown_hook(self) -> None:
         fn = self.state.fn._setup_fn
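With `_handle_meta_model` in place, a meta-initialized module can be handed straight to the Trainer. A rough usage sketch, assuming a self-contained LightningModule such as the `BoringModel` fixture used in the tests further down (which bundles its own dataloaders):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.meta import init_meta_context

    with init_meta_context():
        model = BoringModel(4)  # weights are allocated on the meta device

    trainer = Trainer(fast_dev_run=True)
    # fit() reaches _call_configure_sharded_model, which now materializes the
    # weights and re-attaches the trainer proxy instead of assuming a real model
    trainer.fit(model)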
pytorch_lightning/utilities/meta.py
@@ -18,13 +18,14 @@ from contextlib import contextmanager
 from functools import partial
 from itertools import chain
 from types import ModuleType
-from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type
 
 import torch
 from torch import nn, Tensor
 from torch.nn import Module
 from torch.nn.modules.container import ModuleDict, ModuleList, Sequential
 
+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
 
 # cache subclasses to optimize the search when resetting the meta device later on.
 __STORAGE_META__ = {}
-
 __CREATED_MODULES__ = set()
 
@@ -237,45 +237,52 @@ def _set_meta_device() -> None:
 
     for subclass in get_all_subclasses(torch.nn.modules.module.Module):
 
-        if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
+        if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
             continue
 
         # if a subclass has already been stored, we should use the cache
        if str(subclass) in __STORAGE_META__:
-            # reset the class import package to its rightfull state.
+            # reset the class import package to its rightful state.
             mods, subclass, meta_class = __STORAGE_META__[subclass]
             for mod in mods:
                 setattr(mod, subclass.__name__, meta_class)
             continue
 
+        class _IsinstanceMetaclass(type(subclass)):
+            def __instancecheck__(self, instance: Any) -> bool:
+                """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
+                return isinstance(instance, self.__bases__[0])
+
         # Create a class subclassing the current `subclass` and overriding its `__new__` method.
         # This will enable us to use `torch.distributed.nn.utils.init_meta` to create a `meta`
         # version of the current subclass module
-        class _MetaClass(subclass):
+        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
             @classmethod
             @contextmanager
-            def instantiation_context(cls, materialize: bool):
+            def instantiation_context(cls):
                 _unset_meta_device(from_created=True)
                 yield
                 _set_meta_device_populated(from_created=True)
 
             @classmethod
             def materialize(cls, materialize_fn: Callable):
-                with cls.instantiation_context(materialize=True):
+                with cls.instantiation_context():
                     obj = materialize_fn()
                 return obj
 
             @staticmethod
             def add_subclasses(subclass):
-                """This is used to unrol the instantion tree while creating the modules."""
-                __CREATED_MODULES__.add(subclass)
+                """This is used to unroll the instantiation tree while creating the modules."""
+                # Don't store the LightningModule as skipped from the Meta process.
+                if subclass != pl.LightningModule:
+                    __CREATED_MODULES__.add(subclass)
                 if subclass.__bases__[0] != torch.nn.modules.module.Module:
-                    _MetaClass.add_subclasses(subclass.__bases__[0])
+                    _MaterializerModule.add_subclasses(subclass.__bases__[0])
 
             def __new__(cls, *args, **kwargs):
                 subclass = cls.__bases__[0]
                 cls.add_subclasses(subclass)
-                with cls.instantiation_context(materialize=False):
+                with cls.instantiation_context():
                     obj = init_meta(subclass, *args, **kwargs)
 
                 obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
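The `__instancecheck__` override is what repairs `isinstance`: the dynamically created `_MaterializerModule` replaces the original class at its import sites, so checks against it must be delegated back to the real base class. A standalone sketch of the mechanism, using `nn.Linear` and a hypothetical `_FakeLinear`:

    import torch.nn as nn

    class _IsinstanceMetaclass(type(nn.Linear)):
        def __instancecheck__(self, instance):
            # delegate the check to the wrapped base class
            return isinstance(instance, self.__bases__[0])

    class _FakeLinear(nn.Linear, metaclass=_IsinstanceMetaclass):
        pass

    # A plain nn.Linear instance now passes an isinstance check against the
    # stand-in class, which is what the DeviceDtypeModuleMixin traversal needs.
    assert isinstance(nn.Linear(1, 1), _FakeLinear)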
@@ -294,9 +301,8 @@ def _set_meta_device() -> None:
         # the nn.Module class can be imported at different levels and they all need to be mocked.
         # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
         # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
-        # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
-        out = []
-        out.append(search(mod))
+        # need to be replaced by the torch.nn.modules.linear.Linear _MaterializerModule
+        out = [search(mod)]
         for name in submodules[1:]:
             mod = getattr(mod, name)
             out.append(search(mod))
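The module search above exists because a single class object is reachable under several import paths; a quick demonstration of the aliasing that forces `_set_meta_device` to rebind every re-export:

    import torch

    # one class object, three import paths -- each module that re-exports it
    # must have the name rebound to the generated stand-in class
    assert torch.nn.Linear is torch.nn.modules.linear.Linear
    assert torch.nn.modules.Linear is torch.nn.modules.linear.Linear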
@@ -305,11 +311,11 @@ def _set_meta_device() -> None:
         mods = [mod for mod in chain(*out) if mod]
 
         # store the modules search so it doesn't have to be performed again for this class
-        __STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
+        __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)
 
         # replace all subclasses by their meta form
         for mod in mods:
-            setattr(mod, subclass.__name__, _MetaClass)
+            setattr(mod, subclass.__name__, _MaterializerModule)
 
 
 @contextmanager
@@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
     _set_meta_device()
     yield
     _unset_meta_device()
+
+
+def is_on_meta_device(module: nn.Module) -> bool:
+    try:
+        param = next(module.parameters())
+        return param.device.type == "meta"
+    except StopIteration:
+        return False
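A small sketch of the new helper's two paths: parameters allocated on the meta device return True, while a module with no parameters at all hits the `StopIteration` branch and returns False:

    import torch.nn as nn
    from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device

    with init_meta_context():
        assert is_on_meta_device(nn.Linear(1, 1))  # weights live on "meta"

    assert not is_on_meta_device(nn.ReLU())  # no parameters -> StopIteration -> False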
tests/utilities/test_meta.py
@@ -14,7 +14,7 @@
 from torch import nn
 
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
+from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
 from tests.helpers.runif import RunIf
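The test below also instantiates an `MLP` fixture defined earlier in this file. The hunk does not show its definition, but a plausible minimal shape (a hypothetical reconstruction, mirroring the `BoringModel.__init__` shown below) would be:

    class MLP(nn.Module):
        def __init__(self, num_layers: int):
            super().__init__()
            # a stack of tiny linear layers, matching the test's `mlp.layer[0]` access
            self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)])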
@@ -31,18 +31,23 @@ class BoringModel(LightningModule):
         self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)])
 
 
-@RunIf(min_torch="1.10.0")
+@RunIf(special=True, min_torch="1.10.0")
 def test_init_meta_context():
 
     with init_meta_context():
         m = nn.Linear(in_features=1, out_features=1)
+        assert isinstance(m, nn.Linear)
         assert m.weight.device.type == "meta"
+        assert is_on_meta_device(m)
         mlp = MLP(4)
         assert mlp.layer[0].weight.device.type == "meta"
 
         mlp = materialize_module(mlp)
         assert mlp.layer[0].weight.device.type == "cpu"
 
+        assert not is_on_meta_device(mlp)
+        assert not is_on_meta_device(nn.Module())
+
         model = BoringModel(4)
         assert model.layer[0].weight.device.type == "meta"
         materialize_module(model)