Update Accelerator Connector for Registry (#7214)
This commit is contained in:
parent b7a444883c
commit 6d7c6d6403
@@ -82,6 +82,7 @@ class _TrainingTypePluginsRegistry(UserDict):

         def do_register(plugin: Callable) -> Callable:
             data["plugin"] = plugin
+            data["distributed_backend"] = plugin.distributed_backend
             self[name] = data
             return plugin

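The registry entry now also records the plugin's `distributed_backend`. A minimal, self-contained sketch of that pattern, assuming a dict-backed registry and an illustrative `MyPlugin`; none of these names are the Lightning objects themselves:

from collections import UserDict
from typing import Callable


class PluginRegistrySketch(UserDict):
    """Dict-backed stand-in for a plugins registry."""

    def register(self, name: str, **init_params) -> Callable:
        data = {"init_params": init_params}

        def do_register(plugin: Callable) -> Callable:
            data["plugin"] = plugin
            # the new field: remember which distributed backend the plugin maps to
            data["distributed_backend"] = plugin.distributed_backend
            self[name] = data
            return plugin

        return do_register


registry = PluginRegistrySketch()


@registry.register("my_plugin")
class MyPlugin:
    distributed_backend = "ddp"  # illustrative value


assert registry["my_plugin"]["distributed_backend"] == "ddp"
assert registry["my_plugin"]["plugin"] is MyPlugin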
@@ -112,6 +112,16 @@ class AcceleratorConnector(object):
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
         self._cluster_environment: Optional[ClusterEnvironment] = None

+        plugins = plugins if plugins is not None else []
+
+        if isinstance(plugins, str):
+            plugins = [plugins]
+
+        if not isinstance(plugins, Sequence):
+            plugins = [plugins]
+
+        self.plugins = plugins
+
         # for gpus allow int, string and gpu list
         if auto_select_gpus and isinstance(gpus, int):
             self.gpus = pick_multiple_gpus(gpus)
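The constructor now normalizes the `plugins` argument up front so `self.plugins` is always a list before any other setup runs. A sketch of the same coercion written as a free function; the function name is made up for the example and is not part of the diff:

from collections.abc import Sequence
from typing import Any, List


def normalize_plugins(plugins: Any) -> List[Any]:
    """Coerce None, a single name, or a single plugin object into a list."""
    plugins = plugins if plugins is not None else []
    if isinstance(plugins, str):
        # a single registry name such as "ddp_sharded"
        plugins = [plugins]
    if not isinstance(plugins, Sequence):
        # a single plugin object rather than a sequence of them
        plugins = [plugins]
    return list(plugins)


assert normalize_plugins(None) == []
assert normalize_plugins("ddp_sharded") == ["ddp_sharded"]
assert normalize_plugins(["a", "b"]) == ["a", "b"]
assert len(normalize_plugins(object())) == 1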
@@ -121,7 +131,7 @@ class AcceleratorConnector(object):
         self.set_distributed_mode()
         self.configure_slurm_ddp()

-        self.handle_given_plugins(plugins)
+        self.handle_given_plugins()

         self.accelerator = self.select_accelerator()

@@ -148,22 +158,13 @@ class AcceleratorConnector(object):

         self.replace_sampler_ddp = replace_sampler_ddp

-    def handle_given_plugins(
-        self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
-    ):
-        plugins = plugins if plugins is not None else []
-
-        if isinstance(plugins, str):
-            plugins = [plugins]
-
-        if not isinstance(plugins, Sequence):
-            plugins = [plugins]
+    def handle_given_plugins(self) -> None:

         training_type = None
         precision = None
         cluster_environment = None

-        for plug in plugins:
+        for plug in self.plugins:
             if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                 if training_type is None:
                     training_type = TrainingTypePluginsRegistry.get(plug)
@@ -173,7 +174,7 @@ class AcceleratorConnector(object):
                         ' Found more than 1 training type plugin:'
                         f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
                     )
-            elif isinstance(plug, str):
+            if isinstance(plug, str):
                 # Reset the distributed type as the user has overridden training type
                 # via the plugins argument
                 self._distrib_type = None
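handle_given_plugins now iterates `self.plugins`, and the second string check becomes a plain `if` so that registry names also flow into the distributed-mode handling. A compressed sketch of that dispatch with a plain dict standing in for the registry; names and values are illustrative only:

# stand-in registry, not the Lightning one
REGISTRY = {
    "ddp_find_unused_parameters_false": {"plugin": "DDPPlugin", "distributed_backend": "ddp"},
}


def resolve(plugins):
    training_type = None
    backend = None
    for plug in plugins:
        if isinstance(plug, str) and plug in REGISTRY:
            # a registered name resolves to its training type plugin
            training_type = REGISTRY[plug]["plugin"]
        if isinstance(plug, str):
            # plain `if`, not `elif`: registry names also re-derive the backend,
            # translated through the registry when possible
            backend = REGISTRY[plug]["distributed_backend"] if plug in REGISTRY else plug
    return training_type, backend


assert resolve(["ddp_find_unused_parameters_false"]) == ("DDPPlugin", "ddp")
assert resolve(["ddp_spawn"]) == (None, "ddp_spawn")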
@@ -310,6 +311,10 @@ class AcceleratorConnector(object):
     def root_gpu(self) -> Optional[int]:
         return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None

+    @property
+    def is_training_type_in_plugins(self) -> bool:
+        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+
     @property
     def is_using_torchelastic(self) -> bool:
         """
@@ -492,7 +497,12 @@ class AcceleratorConnector(object):

     def set_distributed_mode(self, distributed_backend: Optional[str] = None):

-        if distributed_backend is not None:
+        if distributed_backend is None and self.is_training_type_in_plugins:
+            return
+
+        if distributed_backend is not None and distributed_backend in TrainingTypePluginsRegistry:
+            self.distributed_backend = TrainingTypePluginsRegistry[distributed_backend]["distributed_backend"]
+        elif distributed_backend is not None:
             self.distributed_backend = distributed_backend

         if isinstance(self.distributed_backend, Accelerator):
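set_distributed_mode gains two registry-aware branches: an early return when no backend is requested but a registry plugin was already given, and a translation from a registry name to its recorded distributed_backend. A sketch of that precedence as a pure function; the helper and registry contents are hypothetical, chosen only to show the ordering:

# stand-in registry; the key/value pair is illustrative
REGISTRY = {"deepspeed_stage_2": {"distributed_backend": "deepspeed"}}


def pick_backend(distributed_backend, current_backend, training_type_in_plugins):
    # nothing requested and a registry plugin already decides the mode: keep as-is
    if distributed_backend is None and training_type_in_plugins:
        return current_backend
    # a registry name is translated to the backend it was registered with
    if distributed_backend is not None and distributed_backend in REGISTRY:
        return REGISTRY[distributed_backend]["distributed_backend"]
    # any other explicit name is taken verbatim
    if distributed_backend is not None:
        return distributed_backend
    return current_backend


assert pick_backend("deepspeed_stage_2", None, False) == "deepspeed"
assert pick_backend("ddp", None, False) == "ddp"
assert pick_backend(None, "ddp", True) == "ddp"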
@@ -23,6 +23,8 @@ def test_training_type_plugins_registry_with_new_plugin():

     class TestPlugin:

+        distributed_backend = "test_plugin"
+
         def __init__(self, param1, param2):
             self.param1 = param1
             self.param2 = param2
@@ -37,6 +39,7 @@ def test_training_type_plugins_registry_with_new_plugin():
     assert plugin_name in TrainingTypePluginsRegistry
     assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
     assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123}
+    assert TrainingTypePluginsRegistry[plugin_name]["distributed_backend"] == "test_plugin"
     assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)

     TrainingTypePluginsRegistry.remove(plugin_name)
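The updated test asserts that a registered entry carries the new distributed_backend field alongside its init params. A self-contained, pytest-style sketch of the same kind of check against a stand-in registry defined inline; none of these objects are the Lightning ones:

from collections import UserDict


class SketchRegistry(UserDict):
    def register(self, name, **init_params):
        def do_register(plugin):
            self[name] = {
                "plugin": plugin,
                "init_params": init_params,
                "distributed_backend": plugin.distributed_backend,
            }
            return plugin

        return do_register


def test_registry_records_distributed_backend():
    registry = SketchRegistry()

    @registry.register("test_plugin", param1="abc", param2=123)
    class TestPlugin:
        distributed_backend = "test_plugin"

        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2

    assert registry["test_plugin"]["distributed_backend"] == "test_plugin"
    assert registry["test_plugin"]["init_params"] == {"param1": "abc", "param2": 123}
    assert registry["test_plugin"]["plugin"] is TestPlugin


test_registry_records_distributed_backend()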