Update Accelerator Connector for Registry (#7214)

Kaushik B 2021-05-04 02:33:21 +05:30 committed by GitHub
parent b7a444883c
commit 6d7c6d6403
3 changed files with 28 additions and 14 deletions


@@ -82,6 +82,7 @@ class _TrainingTypePluginsRegistry(UserDict):
         def do_register(plugin: Callable) -> Callable:
             data["plugin"] = plugin
+            data["distributed_backend"] = plugin.distributed_backend
             self[name] = data
             return plugin

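For illustration only (not part of this commit): a minimal sketch of registering a plugin class so that the new "distributed_backend" entry is populated. The register() signature and the import path are assumed here, inferred from the test change further below; MyCustomPlugin is a hypothetical class.

from pytorch_lightning.plugins import TrainingTypePluginsRegistry  # import path assumed


class MyCustomPlugin:
    # the registry now stores this attribute under data["distributed_backend"]
    distributed_backend = "my_custom_backend"

    def __init__(self, bucket_size: int = 32):
        self.bucket_size = bucket_size


TrainingTypePluginsRegistry.register(
    "my_custom_plugin",              # registry key
    MyCustomPlugin,                  # stored under data["plugin"]
    description="Example plugin registration",
    bucket_size=64,                  # captured under data["init_params"]
)

assert TrainingTypePluginsRegistry["my_custom_plugin"]["distributed_backend"] == "my_custom_backend"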

@@ -112,6 +112,16 @@ class AcceleratorConnector(object):
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
         self._cluster_environment: Optional[ClusterEnvironment] = None
 
+        plugins = plugins if plugins is not None else []
+
+        if isinstance(plugins, str):
+            plugins = [plugins]
+
+        if not isinstance(plugins, Sequence):
+            plugins = [plugins]
+
+        self.plugins = plugins
+
         # for gpus allow int, string and gpu list
         if auto_select_gpus and isinstance(gpus, int):
             self.gpus = pick_multiple_gpus(gpus)
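The normalisation now done once in __init__ behaves like the stand-alone sketch below (normalize_plugins is a hypothetical helper used only to spell the rules out):

from typing import Any, Sequence


def normalize_plugins(plugins: Any) -> list:
    # None -> [], a bare string -> [string], a single plugin object -> [object]
    plugins = plugins if plugins is not None else []
    if isinstance(plugins, str):
        plugins = [plugins]
    if not isinstance(plugins, Sequence):
        plugins = [plugins]
    return list(plugins)


assert normalize_plugins(None) == []
assert normalize_plugins("my_custom_plugin") == ["my_custom_plugin"]
assert len(normalize_plugins(object())) == 1  # a single plugin instance gets wrapped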
@@ -121,7 +131,7 @@ class AcceleratorConnector(object):
 
         self.set_distributed_mode()
         self.configure_slurm_ddp()
 
-        self.handle_given_plugins(plugins)
+        self.handle_given_plugins()
 
         self.accelerator = self.select_accelerator()
@@ -148,22 +158,13 @@ class AcceleratorConnector(object):
         self.replace_sampler_ddp = replace_sampler_ddp
 
-    def handle_given_plugins(
-        self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
-    ):
-        plugins = plugins if plugins is not None else []
-
-        if isinstance(plugins, str):
-            plugins = [plugins]
-
-        if not isinstance(plugins, Sequence):
-            plugins = [plugins]
+    def handle_given_plugins(self) -> None:
 
         training_type = None
         precision = None
         cluster_environment = None
 
-        for plug in plugins:
+        for plug in self.plugins:
             if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                 if training_type is None:
                     training_type = TrainingTypePluginsRegistry.get(plug)
@@ -173,7 +174,7 @@ class AcceleratorConnector(object):
                         ' Found more than 1 training type plugin:'
                         f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
                     )
-            elif isinstance(plug, str):
+            if isinstance(plug, str):
                 # Reset the distributed type as the user has overridden training type
                 # via the plugins argument
                 self._distrib_type = None
@@ -310,6 +311,10 @@ class AcceleratorConnector(object):
     def root_gpu(self) -> Optional[int]:
         return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None
 
+    @property
+    def is_training_type_in_plugins(self) -> bool:
+        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+
     @property
     def is_using_torchelastic(self) -> bool:
         """
@@ -492,7 +497,12 @@ class AcceleratorConnector(object):
     def set_distributed_mode(self, distributed_backend: Optional[str] = None):
 
-        if distributed_backend is not None:
+        if distributed_backend is None and self.is_training_type_in_plugins:
+            return
+
+        if distributed_backend is not None and distributed_backend in TrainingTypePluginsRegistry:
+            self.distributed_backend = TrainingTypePluginsRegistry[distributed_backend]["distributed_backend"]
+        elif distributed_backend is not None:
             self.distributed_backend = distributed_backend
 
         if isinstance(self.distributed_backend, Accelerator):

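To make the new resolution order concrete, a small sketch of the lookup set_distributed_mode() now performs, reusing the hypothetical "my_custom_plugin" registration from above (this is not the connector's code, just the same lookup spelled out):

from pytorch_lightning.plugins import TrainingTypePluginsRegistry  # import path assumed

distributed_backend = "my_custom_plugin"  # hypothetical registered name

if distributed_backend in TrainingTypePluginsRegistry:
    # a registered name resolves to the backend string captured at registration time
    resolved = TrainingTypePluginsRegistry[distributed_backend]["distributed_backend"]
else:
    # any other non-None value (e.g. "ddp") is used as-is
    resolved = distributed_backend

assert resolved == "my_custom_backend"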

@@ -23,6 +23,8 @@ def test_training_type_plugins_registry_with_new_plugin():
 
     class TestPlugin:
 
+        distributed_backend = "test_plugin"
+
         def __init__(self, param1, param2):
             self.param1 = param1
             self.param2 = param2
@@ -37,6 +39,7 @@ def test_training_type_plugins_registry_with_new_plugin():
     assert plugin_name in TrainingTypePluginsRegistry
     assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
     assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123}
+    assert TrainingTypePluginsRegistry[plugin_name]["distributed_backend"] == "test_plugin"
     assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)
 
     TrainingTypePluginsRegistry.remove(plugin_name)
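Usage note (not part of this commit): once a real TrainingTypePlugin subclass is registered under a name, that name can be passed straight through the Trainer's plugins argument; the string is wrapped into a list by __init__ and resolved against the registry in handle_given_plugins(). The plugin name below is hypothetical.

from pytorch_lightning import Trainer

# "my_registered_plugin" is assumed to map to a real TrainingTypePlugin
# subclass in TrainingTypePluginsRegistry
trainer = Trainer(plugins="my_registered_plugin")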