lightning/pytorch_lightning/plugins/plugin.py

Allow string plugins (#4888)

* Allow plugin to be chosen via string
* Fix implementation, add tests
* Fix codefactor issues
* Added missing env patch
* Skip test for windows
* Reword reason
* Add skip to invalid test
* Create required_plugins function, move sharded amp requirement to plugin
* Pass AMPType, fix setter for apex
* Better doc strings
* Add exception when using apex
* Add trainer available_plugins function, warn user when plugins have been added automatically with option to override behaviour
* Fixed pep8 indent
* Fix codefactor issues
* Add env variables
* Update pytorch_lightning/cluster_environments/cluster_environment.py
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Addressed code review
* Update pytorch_lightning/plugins/plugin_connector.py
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Update pytorch_lightning/plugins/plugin_connector.py
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Update pytorch_lightning/plugins/plugin_connector.py
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Addressed more code review feedback
* Fixed docstrings
* Swapped to verbose runtime error
* Apply suggestions from code review
* Apply suggestions from code review
* Update pytorch_lightning/plugins/sharded_plugin.py
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Change name
* Pass trainer to plugins that may require it
* Fix sharded plugin
* Added test to ensure string sharded works
* Removed trainer typing as this breaks pep8
* Fixed doc issues
* Fixed tests

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

2020-12-01 20:30:49 +00:00
from pytorch_lightning.utilities import AMPType


class LightningPlugin:
    """
    Defines the base class for plugins. Plugins represent functionality that can be injected into the Lightning codebase.
    """

    def required_plugins(self, amp_backend: AMPType, trainer) -> list:
        """
        Override to define additional required plugins. This is useful when a custom plugin
        needs to enforce the override of other plugins.

        Args:
            amp_backend: The AMP backend type in use.
            trainer: The trainer instance, passed for plugins that may require it.

        Returns: A list of additional plugins that this plugin requires (empty by default).

        Example::

            class MyPlugin(DDPPlugin):
                def required_plugins(self, amp_backend, trainer):
                    return [MyCustomAMPPlugin()]

            # Will automatically add the necessary AMP plugin
            trainer = Trainer(plugins=[MyPlugin()])

            # Crashes, as MyPlugin enforces its custom AMP plugin
            trainer = Trainer(plugins=[MyPlugin(), NativeAMPPlugin()])
        """
        return []
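
The following is a minimal, self-contained sketch of how a caller might use the `required_plugins` hook. It does not import Lightning and does not reproduce Lightning's actual plugin connector. `AMPType` and `LightningPlugin` here are simplified stand-ins, and `MyPlugin`, `MyCustomAMPPlugin`, and `collect_required_plugins` are hypothetical names introduced only for illustration. The conflict check that makes the second `Trainer` call in the docstring fail is not modelled.

```python
from enum import Enum


class AMPType(Enum):
    # Simplified stand-in for pytorch_lightning.utilities.AMPType (assumption).
    APEX = "apex"
    NATIVE = "native"


class LightningPlugin:
    # Stand-in base class mirroring the hook shown above.
    def required_plugins(self, amp_backend, trainer) -> list:
        return []


class MyCustomAMPPlugin(LightningPlugin):
    """Hypothetical AMP plugin with no requirements of its own."""


class MyPlugin(LightningPlugin):
    """Hypothetical plugin that requires MyCustomAMPPlugin."""

    def required_plugins(self, amp_backend, trainer) -> list:
        return [MyCustomAMPPlugin()]


def collect_required_plugins(plugins, amp_backend, trainer=None):
    """Hypothetical helper: append each required plugin whose type is not already present."""
    resolved = list(plugins)
    for plugin in plugins:
        for required in plugin.required_plugins(amp_backend, trainer):
            if not any(type(p) is type(required) for p in resolved):
                resolved.append(required)
    return resolved


resolved = collect_required_plugins([MyPlugin()], AMPType.NATIVE)
names = [type(p).__name__ for p in resolved]
print(names)
```

The helper only adds missing plugins. A real integration would also need to decide what happens when a user-supplied plugin conflicts with a required one.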