diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py
index 319c9053c2..8da39b92d0 100644
--- a/pytorch_lightning/plugins/plugins_registry.py
+++ b/pytorch_lightning/plugins/plugins_registry.py
@@ -82,6 +82,7 @@ class _TrainingTypePluginsRegistry(UserDict):
 
         def do_register(plugin: Callable) -> Callable:
             data["plugin"] = plugin
+            data["distributed_backend"] = plugin.distributed_backend
             self[name] = data
             return plugin
 
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 475f935fd8..a8a72c1831 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -112,6 +112,16 @@ class AcceleratorConnector(object):
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
         self._cluster_environment: Optional[ClusterEnvironment] = None
 
+        plugins = plugins if plugins is not None else []
+
+        if isinstance(plugins, str):
+            plugins = [plugins]
+
+        if not isinstance(plugins, Sequence):
+            plugins = [plugins]
+
+        self.plugins = plugins
+
         # for gpus allow int, string and gpu list
         if auto_select_gpus and isinstance(gpus, int):
             self.gpus = pick_multiple_gpus(gpus)
@@ -121,7 +131,7 @@ class AcceleratorConnector(object):
         self.set_distributed_mode()
         self.configure_slurm_ddp()
 
-        self.handle_given_plugins(plugins)
+        self.handle_given_plugins()
 
         self.accelerator = self.select_accelerator()
 
@@ -148,22 +158,13 @@ class AcceleratorConnector(object):
 
         self.replace_sampler_ddp = replace_sampler_ddp
 
-    def handle_given_plugins(
-        self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
-    ):
-        plugins = plugins if plugins is not None else []
-
-        if isinstance(plugins, str):
-            plugins = [plugins]
-
-        if not isinstance(plugins, Sequence):
-            plugins = [plugins]
+    def handle_given_plugins(self) -> None:
 
         training_type = None
         precision = None
         cluster_environment = None
 
-        for plug in plugins:
+        for plug in self.plugins:
             if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                 if training_type is None:
                     training_type = TrainingTypePluginsRegistry.get(plug)
@@ -173,7 +174,7 @@ class AcceleratorConnector(object):
                         ' Found more than 1 training type plugin:'
                         f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
                     )
-            elif isinstance(plug, str):
+            if isinstance(plug, str):
                 # Reset the distributed type as the user has overridden training type
                 # via the plugins argument
                 self._distrib_type = None
@@ -310,6 +311,10 @@ class AcceleratorConnector(object):
     def root_gpu(self) -> Optional[int]:
         return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None
 
+    @property
+    def is_training_type_in_plugins(self) -> bool:
+        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+
     @property
     def is_using_torchelastic(self) -> bool:
         """
@@ -492,7 +497,12 @@ class AcceleratorConnector(object):
 
     def set_distributed_mode(self, distributed_backend: Optional[str] = None):
 
-        if distributed_backend is not None:
+        if distributed_backend is None and self.is_training_type_in_plugins:
+            return
+
+        if distributed_backend is not None and distributed_backend in TrainingTypePluginsRegistry:
+            self.distributed_backend = TrainingTypePluginsRegistry[distributed_backend]["distributed_backend"]
+        elif distributed_backend is not None:
             self.distributed_backend = distributed_backend
 
         if isinstance(self.distributed_backend, Accelerator):
diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py
index 91d9596578..c0e112bb6f 100644
--- a/tests/plugins/test_plugins_registry.py
+++ b/tests/plugins/test_plugins_registry.py
@@ -23,6 +23,8 @@ def test_training_type_plugins_registry_with_new_plugin():
 
     class TestPlugin:
 
+        distributed_backend = "test_plugin"
+
         def __init__(self, param1, param2):
             self.param1 = param1
             self.param2 = param2
@@ -37,6 +39,7 @@ def test_training_type_plugins_registry_with_new_plugin():
     assert plugin_name in TrainingTypePluginsRegistry
     assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
     assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123}
+    assert TrainingTypePluginsRegistry[plugin_name]["distributed_backend"] == "test_plugin"
    assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)
 
     TrainingTypePluginsRegistry.remove(plugin_name)
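
For context, a minimal usage sketch of what this change enables. The plugin name, the register() keyword arguments, and the find_unused_parameters init param below are illustrative assumptions, not taken from this diff: the point is that each registered training type plugin now carries a distributed_backend string, so a registry name passed through Trainer(plugins=[...]) can be resolved without the user also setting a distributed backend.

# Hypothetical usage sketch (not part of this diff); names and register()
# arguments are assumptions used only to illustrate the new
# distributed_backend lookup.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry


class MyDDPPlugin(DDPPlugin):
    # The registry now reads this attribute and stores it under the
    # "distributed_backend" key when the plugin is registered.
    distributed_backend = "ddp"


TrainingTypePluginsRegistry.register(
    "my_ddp",
    MyDDPPlugin,
    description="DDP registered under a custom name",
    find_unused_parameters=False,
)

# The connector normalizes plugins into a list, detects the registry entry
# via is_training_type_in_plugins, and maps it back to the "ddp" backend in
# set_distributed_mode(), so a plain string is enough here.
trainer = Trainer(plugins=["my_ddp"])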