178 lines
6.7 KiB
Python
178 lines
6.7 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from enum import Enum
|
|
from typing import List, Optional, Union
|
|
|
|
from pytorch_lightning.cluster_environments import ClusterEnvironment
|
|
from pytorch_lightning.plugins.apex import ApexPlugin
|
|
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
|
|
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
|
|
from pytorch_lightning.plugins.plugin import LightningPlugin
|
|
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
|
|
from pytorch_lightning.utilities import AMPType, rank_zero_warn
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
|
|
|
|
class PluginConnector:
|
|
|
|
def __init__(self, trainer):
|
|
self.trainer = trainer
|
|
self.plugins = []
|
|
self.ddp_plugin = DDPPlugin()
|
|
self.cloud_environment = None
|
|
|
|
def on_trainer_init(self, plugins: Optional[Union[str, list]]):
|
|
self.plugins = plugins
|
|
if self.plugins is None:
|
|
self.plugins = []
|
|
self.plugins = self._convert_str_custom_plugins(self.plugins)
|
|
self.plugins = self._append_required_plugins(self.plugins)
|
|
self.__attach_ddp()
|
|
self.__attach_cluster()
|
|
self.__attach_amp()
|
|
self.__attach_apex()
|
|
|
|
def __attach_amp(self):
|
|
amp_plugin = self.__attach_plugin(NativeAMPPlugin)
|
|
if amp_plugin:
|
|
self.trainer.amp_backend = AMPType.NATIVE
|
|
self.trainer.precision_connector.backend = amp_plugin
|
|
|
|
def __attach_apex(self):
|
|
apex_plugin = self.__attach_plugin(ApexPlugin)
|
|
if apex_plugin:
|
|
self.trainer.amp_backend = AMPType.APEX
|
|
self.trainer.precision_connector.backend = apex_plugin
|
|
|
|
def __attach_plugin(self, plugin_type, limit=1):
|
|
count = 0
|
|
plugin_result = None
|
|
for plugin in self.plugins:
|
|
if isinstance(plugin, plugin_type):
|
|
|
|
# count the clusters
|
|
count += 1
|
|
if count > limit:
|
|
m = f'you can only use one {plugin_type.__class__} in plugins. You passed in: {count}'
|
|
raise MisconfigurationException(m)
|
|
|
|
plugin_result = plugin
|
|
|
|
return plugin_result
|
|
|
|
def __attach_ddp(self, limit=1):
|
|
count = 0
|
|
for plugin in self.plugins:
|
|
if isinstance(plugin, DDPPlugin):
|
|
|
|
# count the clusters
|
|
count += 1
|
|
if count > limit:
|
|
m = f'you can only use one DDP plugin in plugins. You passed in: {count}'
|
|
raise MisconfigurationException(m)
|
|
|
|
# set the ddp plugin
|
|
self.ddp_plugin = plugin
|
|
|
|
def __attach_cluster(self, limit=1):
|
|
num_clusters = 0
|
|
for plugin in self.plugins:
|
|
if isinstance(plugin, ClusterEnvironment):
|
|
|
|
# count the clusters
|
|
num_clusters += 1
|
|
if num_clusters > limit:
|
|
m = f'you can only use one cluster environment in plugins. You passed in: {num_clusters}'
|
|
raise MisconfigurationException(m)
|
|
|
|
# set the cluster
|
|
self.cloud_environment = plugin
|
|
|
|
def _convert_str_custom_plugins(self, plugins: Union[str, list]):
|
|
"""
|
|
Converts string inputs to corresponding supported lightning plugins.
|
|
Args:
|
|
plugins: List of plugins or string to choose lightning plugin.
|
|
|
|
Returns: List of plugins where strings are now plugins.
|
|
"""
|
|
if isinstance(plugins, str):
|
|
return [self._convert_str_to_plugin(plugins)]
|
|
return [self._convert_str_to_plugin(plugin) for plugin in plugins]
|
|
|
|
def _convert_str_to_plugin(self, plugin):
|
|
if isinstance(plugin, str):
|
|
if plugin not in LightningCustomPlugins.__members__:
|
|
raise MisconfigurationException(
|
|
f" {plugin} is not a supported lightning custom plugin."
|
|
" If you're trying to pass a custom plugin, please pass this as an object to"
|
|
" Trainer(plugins=[MyPlugin()]."
|
|
f" Supported plugins as string input: {[e.name for e in LightningCustomPlugins]}."
|
|
)
|
|
plugin_cls = LightningCustomPlugins[plugin].value
|
|
return plugin_cls(trainer=self.trainer)
|
|
return plugin
|
|
|
|
def _append_required_plugins(self, plugins: List[LightningPlugin]):
|
|
"""
|
|
Allows custom plugins to define additional plugins. This is useful for when custom plugins
|
|
need to enforce override of native amp/apex when they are enabled.
|
|
|
|
Args:
|
|
plugins: List of plugins
|
|
|
|
Returns: List of plugins containing additional plugins if needed.
|
|
|
|
Example::
|
|
class MyPlugin(DDPPlugin):
|
|
def required_plugins(self):
|
|
return [MyCustomAMPPlugin()]
|
|
|
|
# Will automatically add the necessary AMP plugin
|
|
trainer = Trainer(plugins=[MyPlugin()])
|
|
|
|
# Crash as MyPlugin enforces custom AMP plugin
|
|
trainer = Trainer(plugins=[MyPlugin(), NativeAMPPlugin()])
|
|
|
|
"""
|
|
for plugin in plugins:
|
|
required_plugins = plugin.required_plugins(amp_backend=self.trainer.amp_backend, trainer=self.trainer)
|
|
if required_plugins:
|
|
rank_zero_warn(
|
|
f'plugin {type(plugin)} has added additional required plugins as default:'
|
|
f' {[type(x) for x in required_plugins]}'
|
|
'Extend this plugin and override `required_plugins`'
|
|
'if this conflicts with your additional plugins.'
|
|
)
|
|
plugins += required_plugins
|
|
return plugins
|
|
|
|
@classmethod
|
|
def available_plugins(cls):
|
|
"""
|
|
List of all available plugins that can be string arguments to the trainer.
|
|
Returns: List of all available plugins that are supported as string arguments.
|
|
"""
|
|
return [e.name for e in LightningCustomPlugins]
|
|
|
|
|
|
class LightningCustomPlugins(Enum):
|
|
"""
|
|
String support for custom lightning plugins.
|
|
Allows easier access to custom lightning plugins from the command line.
|
|
"""
|
|
ddp_sharded = DDPShardedPlugin
|
|
native_amp = NativeAMPPlugin
|
|
apex_amp = ApexPlugin
|