# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import threading
from contextlib import contextmanager
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10

if _TORCH_GREATER_EQUAL_1_10:
    from torch._C import _DisableTorchDispatch  # type: ignore[attr-defined]

    ####################################################################
    # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

    @contextmanager
    def enable_python_mode(cls) -> Iterator[None]:
        if not hasattr(cls, "__torch_dispatch__"):
            raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
        if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
            raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
        torch._C._enter_python_mode(cls)
        try:
            yield
        finally:
            torch._C._exit_python_mode()

    _tls = threading.local()
    _tls.in_call = False

    @contextmanager
    def _no_dispatch() -> Iterator[None]:
        """Temporarily disables the Python dispatch mode."""
        guard = _DisableTorchDispatch()  # noqa F841
        try:
            yield
        finally:
            del guard

    def _handle_arange(func, args, kwargs):
        kwargs["device"] = torch.device("cpu")
        return torch.empty_like(func(*args, **kwargs), device="meta")

    def _handle_tril(func, args, kwargs):
        if args and isinstance(args[0], Tensor):
            return torch.empty_like(args[0], device="meta")

        return NotImplemented

    class _MetaContext(Tensor):
        _op_handlers: Dict[Callable, Callable] = {}

        @classmethod
        def _ensure_handlers_initialized(cls) -> None:
            if cls._op_handlers:
                return

            cls._op_handlers.update(
                {
                    torch.ops.aten.arange: _handle_arange,
                    torch.ops.aten.tril: _handle_tril,
                }
            )

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            cls._ensure_handlers_initialized()

            # ``kwargs`` may be passed as ``None`` by the dispatcher; normalize it before use.
            kwargs = kwargs or {}

            op_handler: Optional[Callable]

            try:
                op_handler = cls._op_handlers[func]
            except KeyError:
                op_handler = None

            with _no_dispatch():
                if op_handler:
                    result = op_handler(func, args, kwargs)
                    if result is not NotImplemented:
                        return result

                if "device" in kwargs:
                    kwargs["device"] = torch.device("meta")

                return func(*args, **kwargs)

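    # Illustrative note (an assumption, not taken from the upstream PR): further aten ops that cannot
    # run directly on the meta device could be intercepted by registering extra handlers. The
    # hypothetical mapping below reuses the CPU round-trip handler for ``torch.ops.aten.full``:
    #
    #   _MetaContext._op_handlers[torch.ops.aten.full] = _handle_arange
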
    def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
        def create_instance(module=None) -> Module:
            if module:
                module.__init__(*args, **kwargs)
                return module
            return module_fn(*args, **kwargs)

        if _tls.in_call:
            module = create_instance()
        else:
            _tls.in_call = True
            try:
                with enable_python_mode(_MetaContext):
                    module = create_instance()
            finally:
                _tls.in_call = False

        module.materialize = partial(create_instance, module=module)  # type: ignore[assignment]

        return module

    def is_meta_init() -> bool:
        """Indicates whether the module is being instantiated by ``init_meta()``."""
        return _tls.in_call

    ####################################################################
    # ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

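    # Illustrative usage sketch (``MyModule`` is a hypothetical user-defined ``nn.Module`` subclass,
    # introduced here only for illustration):
    #
    #   meta_module = init_meta(MyModule, hidden_size=16)  # parameters are allocated on the meta device
    #   real_module = meta_module.materialize()            # same instance, re-initialized with real tensors
    #
    # ``materialize`` re-invokes ``MyModule.__init__`` with the original arguments outside the meta
    # dispatch mode, so the returned module owns regular tensors.
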
else:

    def init_meta(*_, **__):
        if not _TORCH_GREATER_EQUAL_1_10:
            raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")


# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls: Type[nn.Module]) -> Set[Type[nn.Module]]:
    subclass_list = []

    def recurse(cl):
        for subclass in cl.__subclasses__():
            subclass_list.append(subclass)
            recurse(subclass)

    recurse(cls)

    return set(subclass_list)


def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None:
    *path, name = prefix.split(".")
    for p in path:
        root_module = getattr(root_module, p)

    try:
        index = int(name)
        root_module[index] = materialized_module
    except ValueError:
        setattr(root_module, name, materialized_module)


def materialize_module(root_module: nn.Module) -> nn.Module:
    """Materialize a module and all of its children in-place and return the root module."""
    if not _TORCH_GREATER_EQUAL_1_10:
        return root_module

    materialize_fn = getattr(root_module, "materialize", None)
    if materialize_fn and not isinstance(root_module, (Sequential, ModuleList, ModuleDict)):
        return materialize_fn()

    for name, child in root_module.named_children():
        materialize_fn = getattr(child, "materialize", None)
        if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
            materialize_module(child)
        else:
            # replace the meta child on its parent with the materialized module
            setattr(root_module, name, materialize_fn())
    return root_module


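# Illustrative sketch for ``materialize_module`` (``MyEncoder`` is a hypothetical ``nn.Module``
# subclass created under ``init_meta_context()``, defined further below): the helper recurses over
# the children, swaps every meta child for its materialized counterpart and returns the root module:
#
#   with init_meta_context():
#       encoder = MyEncoder()
#   encoder = materialize_module(encoder)

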
# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}
__CREATED_MODULES__ = set()


def _unset_meta_device(from_created: bool = False) -> None:
    """Replace all meta modules with their original versions."""
    if not _TORCH_GREATER_EQUAL_1_10:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
        values = __STORAGE_META__.values()

    for mods, subclass, _ in values:
        for mod in mods:
            setattr(mod, subclass.__name__, subclass)


def _set_meta_device_populated(from_created: bool = False) -> None:
    """Replace all modules with their already-created meta versions."""
    if not _TORCH_GREATER_EQUAL_1_10:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
        values = __STORAGE_META__.values()

    for mods, subclass, meta_class in values:
        for mod in mods:
            setattr(mod, subclass.__name__, meta_class)


def _set_meta_device() -> None:
    """Replace all ``torch.nn.Module`` subclasses with their meta replacements."""

    if not _TORCH_GREATER_EQUAL_1_10:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    # Author note: This can be optimized further by searching all subclasses at once.
    # Its time complexity is O(n * m), where n is the number of all Module subclasses (assuming no
    # multiple inheritance) and m the cost of searching the modules along each subclass' import path.

    for subclass in get_all_subclasses(torch.nn.modules.module.Module):

        if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
            continue

        # if a subclass has already been stored, we should use the cache
        if subclass in __STORAGE_META__:
            # re-apply the cached meta class at every location the subclass can be imported from
            mods, subclass, meta_class = __STORAGE_META__[subclass]
            for mod in mods:
                setattr(mod, subclass.__name__, meta_class)
            continue

        class _IsinstanceMetaclass(type(subclass)):
            def __instancecheck__(self, instance: Any) -> bool:
                """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
                return isinstance(instance, self.__bases__[0])

        # Create a class subclassing the current `subclass` and overriding its `__new__` method.
        # This will enable us to use `torch.distributed.nn.utils.init_meta` to create a `meta`
        # version of the current subclass module.
        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
            @classmethod
            @contextmanager
            def instantiation_context(cls):
                _unset_meta_device(from_created=True)
                yield
                _set_meta_device_populated(from_created=True)

            @classmethod
            def materialize(cls, materialize_fn: Callable):
                with cls.instantiation_context():
                    obj = materialize_fn()
                return obj

            @staticmethod
            def add_subclasses(subclass):
                """This is used to unroll the instantiation tree while creating the modules."""
                # Don't store the LightningModule, as it is skipped from the meta process.
                if subclass != pl.LightningModule:
                    __CREATED_MODULES__.add(subclass)
                if subclass.__bases__[0] != torch.nn.modules.module.Module:
                    _MaterializerModule.add_subclasses(subclass.__bases__[0])

            def __new__(cls, *args, **kwargs):
                subclass = cls.__bases__[0]
                cls.add_subclasses(subclass)
                with cls.instantiation_context():
                    obj = init_meta(subclass, *args, **kwargs)

                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
                return obj

        def search(mod: ModuleType) -> List[ModuleType]:
            out = []
            for _, obj in inspect.getmembers(mod):
                if obj == subclass:
                    out.append(mod)
            return out

        submodules = subclass.__module__.split(".")
        mod = importlib.import_module(submodules[0])

        # An nn.Module class can be imported at several levels and all of them need to be mocked.
        # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear.
        # Therefore, torch.nn.Linear, torch.nn.modules.Linear and torch.nn.modules.linear.Linear
        # all need to be replaced by the _MaterializerModule defined above.
        out = [search(mod)]
        for name in submodules[1:]:
            mod = getattr(mod, name)
            out.append(search(mod))

        # drop empty modules
        mods = [mod for mod in chain(*out) if mod]

        # store the module search so it doesn't have to be performed again for this class
        __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)

        # replace the subclass with its meta form everywhere it is importable
        for mod in mods:
            setattr(mod, subclass.__name__, _MaterializerModule)


@contextmanager
def init_meta_context() -> Generator:
    rank_zero_warn(
        "Be aware this feature is highly experimental and there are a number of weird edge cases "
        "where it can internally assert and/or crash. A more stable version is expected from PyTorch 1.11."
    )
    _set_meta_device()
    yield
    _unset_meta_device()


def is_on_meta_device(module: nn.Module) -> bool:
    try:
        param = next(module.parameters())
        return param.device.type == "meta"
    except StopIteration:
        return False
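

# Minimal end-to-end sketch (illustrative only; ``MyModel`` is a hypothetical user-defined
# ``nn.Module`` subclass):
#
#   with init_meta_context():
#       model = MyModel()              # parameters are created on the meta device
#   assert is_on_meta_device(model)
#   model = materialize_module(model)  # re-creates the parameters with real storage
#   assert not is_on_meta_device(model)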