# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
from collections import UserDict
from inspect import getmembers, isclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class _TrainingTypePluginsRegistry(UserDict):
    """
    Registry that maps string names to Training Type Plugins.

    Each name (e.g. "deepspeed") identifies one registered plugin. Alongside the plugin itself,
    the registry keeps the optional description and the keyword arguments that were supplied
    at registration time; those keyword arguments are used to instantiate the plugin in ``get``.

    The motivation for having a TrainingTypePluginRegistry is to make it convenient
    for the Users to try different Plugins by passing just strings
    to the plugins flag to the Trainer.

    Example::

        @TrainingTypePluginsRegistry.register("lightning", description="Super fast", a=1, b=True)
        class LightningPlugin:
            def __init__(self, a, b):
                ...

        or

        TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True)
    """

    def register(
        self,
        name: str,
        plugin: Optional[Callable] = None,
        description: Optional[str] = None,
        override: bool = False,
        **init_params: Any,
    ) -> Callable:
        """
        Registers a plugin under ``name`` together with its metadata.

        Can be called directly with ``plugin`` given, or used as a decorator factory
        when ``plugin`` is omitted (the returned callable then performs the registration).

        Args:
            name : the name that identifies a plugin, e.g. "deepspeed_stage_3"
            plugin : plugin class
            description : plugin description, stored as an empty string when not given
            override : overrides the registered plugin, if True
            init_params: parameters to initialize the plugin

        Returns:
            The plugin itself when ``plugin`` is given, otherwise a callable that
            registers the plugin passed to it and returns that plugin.

        Raises:
            TypeError: if ``name`` is neither ``None`` nor a ``str``
            MisconfigurationException: if ``name`` is already registered and ``override`` is False
        """
        if name is not None and not isinstance(name, str):
            raise TypeError(f"`name` must be a str, found {name}")

        if name in self and not override:
            raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")

        # The entry is created once here and completed inside the closure below,
        # so both the direct and the decorator call paths share the same metadata.
        entry: Dict[str, Any] = {
            "description": "" if description is None else description,
            "init_params": init_params,
        }

        def do_register(plugin_cls: Callable) -> Callable:
            entry["plugin"] = plugin_cls
            entry["distributed_backend"] = plugin_cls.distributed_backend
            self[name] = entry
            return plugin_cls

        if plugin is None:
            # decorator usage: registration is deferred until the class is supplied
            return do_register
        return do_register(plugin)

    def get(self, name: str, default: Optional[Any] = None) -> Any:
        """
        Instantiates the plugin registered under ``name`` with its stored ``init_params``
        and returns the resulting object.

        Args:
            name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3"
            default: value returned when ``name`` is not registered; a ``None`` default
                is treated as "no default" and leads to a ``KeyError``

        Raises:
            KeyError: if ``name`` is not registered and ``default`` is ``None``
        """
        if name not in self:
            if default is not None:
                return default
            available_names = ", ".join(sorted(self.keys())) or "none"
            err_msg = "'{}' not found in registry. Available names: {}"
            raise KeyError(err_msg.format(name, available_names))

        entry = self[name]
        plugin_cls = entry["plugin"]
        return plugin_cls(**entry["init_params"])

    def remove(self, name: str) -> None:
        """Removes the plugin registered under ``name``"""
        self.pop(name)

    def available_plugins(self) -> List:
        """Returns the names of all registered plugins as a list"""
        return list(self)

    def __str__(self) -> str:
        registered = ", ".join(self.keys())
        return "Registered Plugins: {}".format(registered)


TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry()


def is_register_plugins_overridden(plugin: type) -> bool:
    """
    Tells whether ``plugin`` provides a ``register_plugins`` that differs from the one
    of the next class in its MRO.

    Returns False when that next class is not a ``TrainingTypePlugin`` subclass.
    """
    method_name = "register_plugins"
    plugin_attr = getattr(plugin, method_name)

    parent_cls = inspect.getmro(plugin)[1]
    if not issubclass(parent_cls, TrainingTypePlugin):
        return False
    super_attr = getattr(parent_cls, method_name)

    if hasattr(plugin_attr, "patch_loader_code"):
        # the attribute carries a recorded code string; compare it with the parent's code object as text
        return plugin_attr.patch_loader_code != str(super_attr.__code__)
    return plugin_attr.__code__ is not super_attr.__code__


def call_training_type_register_plugins(root: Path, base_module: str) -> None:
    """
    Imports ``base_module`` and calls ``register_plugins`` with the global registry on every
    class found in it that is a ``TrainingTypePlugin`` subclass overriding ``register_plugins``.

    ``root`` is accepted but not used by this function.
    """
    module = importlib.import_module(base_module)
    for _, candidate in getmembers(module, isclass):
        if not issubclass(candidate, TrainingTypePlugin):
            continue
        if is_register_plugins_overridden(candidate):
            candidate.register_plugins(TrainingTypePluginsRegistry)