# lightning/pytorch_lightning/plugins/plugin_connector.py
# (178 lines, 6.7 KiB, Python)

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import List, Optional, Union
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.plugins.apex import ApexPlugin
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.plugins.plugin import LightningPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class PluginConnector:
    """
    Collects the plugins handed to the trainer and attaches them to it.

    String shortcuts are turned into plugin instances, plugins may pull in
    further required plugins, and the DDP, cluster environment and precision
    (native AMP / Apex) plugins found in the final list are then wired up.
    """

    def __init__(self, trainer):
        """
        Args:
            trainer: The trainer object this connector configures.
        """
        self.trainer = trainer
        self.plugins = []
        # Default DDP plugin; replaced when a DDPPlugin instance is passed in.
        self.ddp_plugin = DDPPlugin()
        # Stays None unless a ClusterEnvironment is passed in as a plugin.
        self.cloud_environment = None

    def on_trainer_init(self, plugins: Optional[Union[str, list]]):
        """
        Resolves ``plugins`` and attaches every recognised plugin.

        Args:
            plugins: ``None``, a single plugin name, or a list mixing plugin
                names and plugin objects.

        Raises:
            MisconfigurationException: If a plugin name is not supported or
                too many plugins of one kind were passed.
        """
        self.plugins = [] if plugins is None else plugins
        self.plugins = self._convert_str_custom_plugins(self.plugins)
        self.plugins = self._append_required_plugins(self.plugins)
        self.__attach_ddp()
        self.__attach_cluster()
        # Apex is attached after native AMP, so it wins when both are present.
        self.__attach_amp()
        self.__attach_apex()

    def __attach_amp(self):
        """Uses a ``NativeAMPPlugin`` from the plugin list as precision backend, if present."""
        amp_plugin = self.__attach_plugin(NativeAMPPlugin)
        if amp_plugin:
            self.trainer.amp_backend = AMPType.NATIVE
            self.trainer.precision_connector.backend = amp_plugin

    def __attach_apex(self):
        """Uses an ``ApexPlugin`` from the plugin list as precision backend, if present."""
        apex_plugin = self.__attach_plugin(ApexPlugin)
        if apex_plugin:
            self.trainer.amp_backend = AMPType.APEX
            self.trainer.precision_connector.backend = apex_plugin

    def __attach_plugin(self, plugin_type, limit=1, description=None):
        """
        Looks up the plugin of a given type in ``self.plugins``.

        Args:
            plugin_type: Class matched against each plugin with ``isinstance``.
            limit: Maximum number of matching plugins that is accepted.
            description: Text used for the plugin kind in the error message.
                Defaults to the name of ``plugin_type``.

        Returns:
            The last matching plugin, or ``None`` if nothing matches.

        Raises:
            MisconfigurationException: If more than ``limit`` plugins match.
        """
        found = None
        count = 0
        for plugin in self.plugins:
            if not isinstance(plugin, plugin_type):
                continue
            count += 1
            if count > limit:
                if description is None:
                    # ``plugin_type`` is itself a class, so report its name;
                    # ``plugin_type.__class__`` would only ever print ``type``.
                    description = getattr(plugin_type, '__name__', plugin_type)
                m = f'you can only use one {description} in plugins. You passed in: {count}'
                raise MisconfigurationException(m)
            found = plugin
        return found

    def __attach_ddp(self, limit=1):
        """
        Replaces the default DDP plugin with the one passed by the user, if any.

        Args:
            limit: Maximum number of DDP plugins that is accepted.

        Raises:
            MisconfigurationException: If more than ``limit`` DDP plugins were passed.
        """
        ddp_plugin = self.__attach_plugin(DDPPlugin, limit=limit, description='DDP plugin')
        if ddp_plugin is not None:
            self.ddp_plugin = ddp_plugin

    def __attach_cluster(self, limit=1):
        """
        Stores the cluster environment passed by the user, if any.

        Args:
            limit: Maximum number of cluster environments that is accepted.

        Raises:
            MisconfigurationException: If more than ``limit`` cluster environments were passed.
        """
        cluster = self.__attach_plugin(ClusterEnvironment, limit=limit, description='cluster environment')
        if cluster is not None:
            self.cloud_environment = cluster

    def _convert_str_custom_plugins(self, plugins: Union[str, list]):
        """
        Converts string inputs to corresponding supported lightning plugins.

        Args:
            plugins: List of plugins or string to choose lightning plugin.

        Returns: List of plugins where strings are now plugins.
        """
        if isinstance(plugins, str):
            return [self._convert_str_to_plugin(plugins)]
        return [self._convert_str_to_plugin(plugin) for plugin in plugins]

    def _convert_str_to_plugin(self, plugin):
        """
        Turns a plugin name into a plugin instance; other objects pass through.

        Args:
            plugin: A plugin object, or the name of a ``LightningCustomPlugins`` member.

        Returns: ``plugin`` itself if it is not a string, otherwise a new instance
            of the named plugin class created with ``trainer=self.trainer``.

        Raises:
            MisconfigurationException: If the name is not a ``LightningCustomPlugins`` member.
        """
        if not isinstance(plugin, str):
            return plugin
        if plugin not in LightningCustomPlugins.__members__:
            raise MisconfigurationException(
                f" {plugin} is not a supported lightning custom plugin."
                " If you're trying to pass a custom plugin, please pass this as an object to"
                " Trainer(plugins=[MyPlugin()])."
                f" Supported plugins as string input: {[e.name for e in LightningCustomPlugins]}."
            )
        plugin_cls = LightningCustomPlugins[plugin].value
        return plugin_cls(trainer=self.trainer)

    def _append_required_plugins(self, plugins: List['LightningPlugin']):
        """
        Allows custom plugins to define additional plugins. This is useful for when custom plugins
        need to enforce override of native amp/apex when they are enabled.

        The list is extended in place, so plugins appended here are visited by the
        same loop and are asked for their own required plugins as well.

        Args:
            plugins: List of plugins

        Returns: List of plugins containing additional plugins if needed.

        Example::

            class MyPlugin(DDPPlugin):
                def required_plugins(self):
                    return [MyCustomAMPPlugin()]

            # Will automatically add the necessary AMP plugin
            trainer = Trainer(plugins=[MyPlugin()])

            # Crash as MyPlugin enforces custom AMP plugin
            trainer = Trainer(plugins=[MyPlugin(), NativeAMPPlugin()])
        """
        for plugin in plugins:
            required_plugins = plugin.required_plugins(amp_backend=self.trainer.amp_backend, trainer=self.trainer)
            if required_plugins:
                # The pieces are joined by implicit concatenation, so each one
                # carries its own separator.
                rank_zero_warn(
                    f'plugin {type(plugin)} has added additional required plugins as default:'
                    f' {[type(x) for x in required_plugins]}.'
                    ' Extend this plugin and override `required_plugins`'
                    ' if this conflicts with your additional plugins.'
                )
                plugins += required_plugins
        return plugins

    @classmethod
    def available_plugins(cls):
        """
        List of all available plugins that can be string arguments to the trainer.

        Returns: List of all available plugins that are supported as string arguments.
        """
        return [e.name for e in LightningCustomPlugins]
class LightningCustomPlugins(Enum):
    """
    Lightning plugins that can be selected by name instead of by object.

    Each member name is a string accepted as a plugin argument (convenient
    when configuring from the command line); the member value is the plugin
    class that the name stands for.
    """

    ddp_sharded = DDPShardedPlugin
    native_amp = NativeAMPPlugin
    apex_amp = ApexPlugin