diff --git a/CHANGELOG.md b/CHANGELOG.md index 7002d16808..25cd769ea0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374)) +- Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493)) + + - Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463)) diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index e02790eddd..e8b122989c 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,6 +17,8 @@ from typing import Any, Optional, Union import torch from torch.nn import Module +import pytorch_lightning as pl + class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ["device", "dtype"] @@ -177,7 +179,9 @@ class DeviceDtypeModuleMixin(Module): self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: - if not isinstance(module, DeviceDtypeModuleMixin): + # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't + # work when using `init_meta_context`. + if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): return if device is not None: module._device = device diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4cbb33c9b4..19efdce8e3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -84,7 +84,7 @@ from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.meta import materialize_module +from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import ( @@ -1406,10 +1406,21 @@ class Trainer( def _call_configure_sharded_model(self) -> None: with self.accelerator.model_sharded_context(): - materialize_module(self.lightning_module) + self._handle_meta_model() self.call_hook("configure_sharded_model") self.call_hook("on_configure_sharded_model") + def _handle_meta_model(self) -> None: + if not is_on_meta_device(self.lightning_module): + return + + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") + + materialize_module(self.lightning_module) + # the trainer reference is lost during materialization + self.lightning_module.trainer = proxy(self) + def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 60e6cc791b..6d3c1d6b5f 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -18,13 +18,14 @@ from contextlib import contextmanager from functools import partial from itertools import chain from types import ModuleType -from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type import torch from torch import nn, Tensor from torch.nn import Module from torch.nn.modules.container import ModuleDict, ModuleList, Sequential +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10 @@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module: # cache subclasses to optimize the search when resetting the meta device later on. __STORAGE_META__ = {} - __CREATED_MODULES__ = set() @@ -237,45 +237,52 @@ def _set_meta_device() -> None: for subclass in get_all_subclasses(torch.nn.modules.module.Module): - if isinstance(subclass, (Sequential, ModuleList, ModuleDict)): + if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule): continue # if a subclass has already been stored, we should use the cache if str(subclass) in __STORAGE_META__: - # reset the class import package to its rightfull state. + # reset the class import package to its rightful state. mods, subclass, meta_class = __STORAGE_META__[subclass] for mod in mods: setattr(mod, subclass.__name__, meta_class) continue + class _IsinstanceMetaclass(type(subclass)): + def __instancecheck__(self, instance: Any) -> bool: + """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects.""" + return isinstance(instance, self.__bases__[0]) + # Create a class subclassing current `subclass` overriding its new method. # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta` # version of the current subclass module - class _MetaClass(subclass): + class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass): @classmethod @contextmanager - def instantiation_context(cls, materialize: bool): + def instantiation_context(cls): _unset_meta_device(from_created=True) yield _set_meta_device_populated(from_created=True) @classmethod def materialize(cls, materialize_fn: Callable): - with cls.instantiation_context(materialize=True): + with cls.instantiation_context(): obj = materialize_fn() return obj @staticmethod def add_subclasses(subclass): - """This is used to unrol the instantion tree while creating the modules.""" - __CREATED_MODULES__.add(subclass) + """This is used to unroll the instantiation tree while creating the modules.""" + # Don't store the LightningModule as skipped from the Meta process. + if subclass != pl.LightningModule: + __CREATED_MODULES__.add(subclass) if subclass.__bases__[0] != torch.nn.modules.module.Module: - _MetaClass.add_subclasses(subclass.__bases__[0]) + _MaterializerModule.add_subclasses(subclass.__bases__[0]) def __new__(cls, *args, **kwargs): subclass = cls.__bases__[0] cls.add_subclasses(subclass) - with cls.instantiation_context(materialize=False): + with cls.instantiation_context(): obj = init_meta(subclass, *args, **kwargs) obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize) @@ -294,9 +301,8 @@ def _set_meta_device() -> None: # nn.Module class can be imported at different level and they all need to be mocked. # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear - # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass - out = [] - out.append(search(mod)) + # needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule + out = [search(mod)] for name in submodules[1:]: mod = getattr(mod, name) out.append(search(mod)) @@ -305,11 +311,11 @@ def _set_meta_device() -> None: mods = [mod for mod in chain(*out) if mod] # store the modules search so it doesn't have to be performed again for this class - __STORAGE_META__[subclass] = (mods, subclass, _MetaClass) + __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule) # replace all subclass by its meta form for mod in mods: - setattr(mod, subclass.__name__, _MetaClass) + setattr(mod, subclass.__name__, _MaterializerModule) @contextmanager @@ -321,3 +327,11 @@ def init_meta_context() -> Generator: _set_meta_device() yield _unset_meta_device() + + +def is_on_meta_device(module: nn.Module) -> bool: + try: + param = next(module.parameters()) + return param.device.type == "meta" + except StopIteration: + return False diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index 8e36a86c3b..581b949d91 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -14,7 +14,7 @@ from torch import nn from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities.meta import init_meta_context, materialize_module +from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module from tests.helpers.runif import RunIf @@ -31,18 +31,23 @@ class BoringModel(LightningModule): self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)]) -@RunIf(min_torch="1.10.0") +@RunIf(special=True, min_torch="1.10.0") def test_init_meta_context(): with init_meta_context(): m = nn.Linear(in_features=1, out_features=1) + assert isinstance(m, nn.Linear) assert m.weight.device.type == "meta" + assert is_on_meta_device(m) mlp = MLP(4) assert mlp.layer[0].weight.device.type == "meta" mlp = materialize_module(mlp) assert mlp.layer[0].weight.device.type == "cpu" + assert not is_on_meta_device(mlp) + assert not is_on_meta_device(nn.Module()) + model = BoringModel(4) assert model.layer[0].weight.device.type == "meta" materialize_module(model)