[Hot Fix] Give priority to plugins to set distributed mode, and then accelerator (#6089)

* Give priority to plugins to set distributed mode, and then accelerator

* Add CHANGELOG.md

* Update CHANGELOG.md

* Remove very scary line

* Ensure we set cluster environment after slurm configured if necessary

* Simplify the fix with a reset

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Sean Naren 2021-02-20 12:58:54 +00:00 committed by GitHub
parent 3bdc0673ea
commit 97a81c3cfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 2 deletions

View File

@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)
- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
## [1.2.0] - 2021-02-18

View File

@ -163,6 +163,9 @@ class AcceleratorConnector(object):
for plug in plugins:
if isinstance(plug, str):
# Reset the distributed type as the user has overridden training type
# via the plugins argument
self._distrib_type = None
self.set_distributed_mode(plug)
elif isinstance(plug, TrainingTypePlugin):
@ -196,7 +199,6 @@ class AcceleratorConnector(object):
)
self._training_type_plugin = training_type
self._training_type_plugin = self.training_type_plugin
self._precision_plugin = precision
self._cluster_environment = cluster_environment or self.select_cluster_environment()

View File

@ -23,7 +23,14 @@ from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
from pytorch_lightning.plugins import (
DDP2Plugin,
DDPPlugin,
DDPShardedPlugin,
DDPSpawnPlugin,
PrecisionPlugin,
SingleDevicePlugin,
)
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from tests.helpers.boring_model import BoringModel
@ -378,3 +385,18 @@ def test_dist_backend_accelerator_mapping(device_count_mock):
with pytest.raises(SystemExit):
trainer.fit(model)
@pytest.mark.parametrize(
["accelerator", "plugin"],
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
)
def test_plugin_accelerator_choice(accelerator, plugin):
"""
Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.
"""
trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
trainer = Trainer(plugins=plugin, num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)