lightning/pytorch_lightning/plugins/plugins_registry.py

148 lines
5.1 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
from collections import UserDict
from inspect import getmembers, isclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class _TrainingTypePluginsRegistry(UserDict):
"""
This class is a Registry that stores information about the Training Type Plugins.
The Plugins are mapped to strings. These strings are names that idenitify
a plugin, e.g., "deepspeed". It also returns Optional description and
parameters to initialize the Plugin, which were defined durng the
registration.
The motivation for having a TrainingTypePluginRegistry is to make it convenient
for the Users to try different Plugins by passing just strings
to the plugins flag to the Trainer.
Example::
@TrainingTypePluginsRegistry.register("lightning", description="Super fast", a=1, b=True)
class LightningPlugin:
def __init__(self, a, b):
...
or
TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True)
"""
def register(
self,
name: str,
plugin: Optional[Callable] = None,
description: Optional[str] = None,
override: bool = False,
**init_params: Any,
) -> Callable:
"""
Registers a plugin mapped to a name and with required metadata.
Args:
name : the name that identifies a plugin, e.g. "deepspeed_stage_3"
plugin : plugin class
description : plugin description
override : overrides the registered plugin, if True
init_params: parameters to initialize the plugin
"""
if not (name is None or isinstance(name, str)):
raise TypeError(f"`name` must be a str, found {name}")
if name in self and not override:
raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")
data: Dict[str, Any] = {}
data["description"] = description if description is not None else ""
data["init_params"] = init_params
def do_register(plugin: Callable) -> Callable:
data["plugin"] = plugin
data["distributed_backend"] = plugin.distributed_backend
self[name] = data
return plugin
if plugin is not None:
return do_register(plugin)
return do_register
def get(self, name: str, default: Optional[Any] = None) -> Any:
"""
Calls the registered plugin with the required parameters
and returns the plugin object
Args:
name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3"
"""
if name in self:
data = self[name]
return data["plugin"](**data["init_params"])
if default is not None:
return default
err_msg = "'{}' not found in registry. Available names: {}"
available_names = ", ".join(sorted(self.keys())) or "none"
raise KeyError(err_msg.format(name, available_names))
def remove(self, name: str) -> None:
"""Removes the registered plugin by name"""
self.pop(name)
def available_plugins(self) -> List:
"""Returns a list of registered plugins"""
return list(self.keys())
def __str__(self) -> str:
return "Registered Plugins: {}".format(", ".join(self.keys()))
TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry()
def is_register_plugins_overridden(plugin: type) -> bool:
method_name = "register_plugins"
plugin_attr = getattr(plugin, method_name)
previous_super_cls = inspect.getmro(plugin)[1]
if issubclass(previous_super_cls, TrainingTypePlugin):
super_attr = getattr(previous_super_cls, method_name)
else:
return False
if hasattr(plugin_attr, "patch_loader_code"):
is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__)
else:
is_overridden = plugin_attr.__code__ is not super_attr.__code__
return is_overridden
def call_training_type_register_plugins(root: Path, base_module: str) -> None:
module = importlib.import_module(base_module)
for _, mod in getmembers(module, isclass):
if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod):
mod.register_plugins(TrainingTypePluginsRegistry)