diff --git a/CHANGELOG.md b/CHANGELOG.md index 66a64d5219..52cdb000a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011) +- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089)) + ## [1.2.0] - 2021-02-18 diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index d32970d61f..7021081d6c 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -163,6 +163,9 @@ class AcceleratorConnector(object): for plug in plugins: if isinstance(plug, str): + # Reset the distributed type as the user has overridden training type + # via the plugins argument + self._distrib_type = None self.set_distributed_mode(plug) elif isinstance(plug, TrainingTypePlugin): @@ -196,7 +199,6 @@ class AcceleratorConnector(object): ) self._training_type_plugin = training_type - self._training_type_plugin = self.training_type_plugin self._precision_plugin = precision self._cluster_environment = cluster_environment or self.select_cluster_environment() diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 76d4a597d8..82b631807c 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -23,7 +23,14 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.gpu import GPUAccelerator from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin +from pytorch_lightning.plugins import ( + DDP2Plugin, + DDPPlugin, + DDPShardedPlugin, + DDPSpawnPlugin, + PrecisionPlugin, + SingleDevicePlugin, +) from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment from tests.helpers.boring_model import BoringModel @@ -378,3 +385,18 @@ def test_dist_backend_accelerator_mapping(device_count_mock): with pytest.raises(SystemExit): trainer.fit(model) + + +@pytest.mark.parametrize( + ["accelerator", "plugin"], + [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')], +) +def test_plugin_accelerator_choice(accelerator, plugin): + """ + Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent. + """ + trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2) + assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + + trainer = Trainer(plugins=plugin, num_processes=2) + assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)