diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py index ea0713b86a..c0ff298b6c 100644 --- a/pytorch_lightning/accelerators/__init__.py +++ b/pytorch_lightning/accelerators/__init__.py @@ -9,3 +9,5 @@ from pytorch_lightning.accelerators.tpu_backend import TPUBackend from pytorch_lightning.accelerators.horovod_backend import HorovodBackend from pytorch_lightning.accelerators.ddp_slurm_backend import DDPSLURMBackend from pytorch_lightning.accelerators.ddp_torchelastic_backend import DDPTorchElasticBackend +from pytorch_lightning.accelerators.ddp_cpu_torchelastic_backend import DDPCPUTorchElasticBackend +from pytorch_lightning.accelerators.ddp_cpu_slurm_backend import DDPCPUSLURMBackend diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 4fdc1e2089..a41ced049c 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -158,6 +158,9 @@ class AcceleratorConnector: use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn" use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu" + use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic() + use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks + # ddp script mode uses the same flags as TE # TODO: decouple from TE if os.environ.get('PL_DDP_PID', False): @@ -167,9 +170,15 @@ class AcceleratorConnector: if self.trainer.use_ddp2: accelerator_backend = accelerators.DDP2Backend(self.trainer) + elif use_ddp_cpu_slurm: + accelerator_backend = accelerators.DDPCPUSLURMBackend(self.trainer) + elif use_slurm_ddp: accelerator_backend = accelerators.DDPSLURMBackend(self.trainer) + elif use_ddp_cpu_torch_elastic: + accelerator_backend = accelerators.DDPCPUTorchElasticBackend(self.trainer) + elif use_torchelastic_ddp: accelerator_backend = accelerators.DDPTorchElasticBackend(self.trainer) diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py index af91ec9313..79aad9aef3 100644 --- a/pytorch_lightning/accelerators/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -29,6 +29,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: @@ -93,6 +94,9 @@ class DDPBackend(Accelerator): # when the trainer script was called the device has already been scoped by the time # code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone # but forward the GPUs selected via environment variables + if self.trainer.data_parallel_device_ids is None: + raise MisconfigurationException('you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)') + os.environ['PL_TRAINER_GPUS'] = ','.join([str(i) for i in self.trainer.data_parallel_device_ids]) os.environ['PL_IN_DDP_SUBPROCESS'] = '1' diff --git a/pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py b/pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py new file mode 100644 index 0000000000..8f6e9065ca --- /dev/null +++ b/pytorch_lightning/accelerators/ddp_cpu_slurm_backend.py @@ -0,0 +1,173 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +import os +import torch +import torch.distributed as torch_distrib +import torch.distributed as dist + +from pytorch_lightning.accelerators.base_backend import Accelerator +from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.distributed.dist import LightningDistributed + + +try: + from hydra.utils import to_absolute_path, get_original_cwd + from hydra.core.hydra_config import HydraConfig +except ImportError: + HYDRA_AVAILABLE = False +else: + HYDRA_AVAILABLE = True + + +# ------------------------------------------- +# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!! +# TEMP CLASS WHILE WE DECOUPLE TE FROM DDP +# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!! +# ------------------------------------------- +class DDPCPUSLURMBackend(Accelerator): + + def __init__(self, trainer, cluster_environment=None): + super().__init__(trainer, cluster_environment) + self.task_idx = None + self._has_spawned_children = False + self.dist = LightningDistributed() + + def setup(self, model): + self.trainer.model = model + self.task_idx = int(os.environ['SLURM_LOCALID']) + + def train(self): + model = self.trainer.model + self.ddp_train(process_idx=self.task_idx, model=model) + + def set_world_ranks(self, process_idx): + self.trainer.local_rank = process_idx + self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx + self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes + + def model_to_device(self, model, process_idx): + model.cpu() + + def get_device_ids(self): + device_ids = None + return device_ids + + def training_step(self, args): + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.trainer.model(*args) + else: + output = self.trainer.model(*args) + return output + + def validation_step(self, args): + output = self.training_step(args) + return output + + def test_step(self, args): + output = self.training_step(args) + return output + + def barrier(self, name: str = None): + if torch_distrib.is_initialized(): + torch_distrib.barrier() + + def early_stopping_should_stop(self, pl_module): + stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) + dist.all_reduce(stop, op=dist.reduce_op.SUM) + dist.barrier() + should_stop = stop == self.trainer.world_size + return should_stop + + def broadcast(self, obj, src=0): + return self.dist.broadcast(obj) + + def ddp_train(self, process_idx, model): + """ + Entry point for ddp + + Args: + process_idx: + mp_queue: multiprocessing queue + model: + + Returns: + + """ + # determine which process we are and world size + self.set_world_ranks(process_idx) + + # toggle prog bar + if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None: + self.trainer.progress_bar_callback.disable() + + # set warning rank + rank_zero_only.rank = self.trainer.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + model.trainer = self.trainer + self.init_ddp_connection( + self.trainer.global_rank, + self.trainer.world_size, + self.trainer.is_slurm_managing_tasks + ) + + # call setup after the ddp process has connected + self.trainer.call_setup_hook(model) + + # on world_size=0 let everyone know training is starting + if self.trainer.is_global_zero and not torch.distributed.is_initialized(): + log.info('-' * 100) + log.info(f'distributed_backend={self.trainer.distributed_backend} (TORCH_ELASTIC)') + log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') + log.info('-' * 100) + + # call sync_bn before .cuda(), configure_apex and configure_ddp + if self.trainer.sync_batchnorm: + model = model.configure_sync_batchnorm(model) + + # move the model to the correct device + self.model_to_device(model, process_idx) + + # CHOOSE OPTIMIZER + # allow for lr schedulers as well + self.setup_optimizers(model) + + # set model properties before going into wrapper + self.trainer.model_connector.copy_trainer_model_properties(model) + + # 16-bit + model = self.trainer.precision_connector.connect(model) + + # device ids change depending on the DDP setup + device_ids = self.get_device_ids() + + # allow user to configure ddp + model = model.configure_ddp(model, device_ids) + + # set up training routine + self.trainer.train_loop.setup_training(model) + + # train or test + results = self.train_or_test() + + # clean up memory + torch.cuda.empty_cache() + + return results diff --git a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py new file mode 100644 index 0000000000..b47d9a6ea3 --- /dev/null +++ b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_backend.py @@ -0,0 +1,173 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +import os +import torch +import torch.distributed as torch_distrib +import torch.distributed as dist + +from pytorch_lightning.accelerators.base_backend import Accelerator +from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.distributed.dist import LightningDistributed + + +try: + from hydra.utils import to_absolute_path, get_original_cwd + from hydra.core.hydra_config import HydraConfig +except ImportError: + HYDRA_AVAILABLE = False +else: + HYDRA_AVAILABLE = True + + +# ------------------------------------------- +# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!! +# TEMP CLASS WHILE WE DECOUPLE TE FROM DDP +# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!! +# ------------------------------------------- +class DDPCPUTorchElasticBackend(Accelerator): + + def __init__(self, trainer, cluster_environment=None): + super().__init__(trainer, cluster_environment) + self.task_idx = None + self._has_spawned_children = False + self.dist = LightningDistributed() + + def setup(self, model): + self.trainer.model = model + self.task_idx = int(os.environ['LOCAL_RANK']) + + def train(self): + model = self.trainer.model + self.ddp_train(process_idx=self.task_idx, model=model) + + def set_world_ranks(self, process_idx): + self.trainer.local_rank = process_idx + self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx + self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes + + def model_to_device(self, model, process_idx): + model.cpu() + + def get_device_ids(self): + device_ids = None + return device_ids + + def training_step(self, args): + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.trainer.model(*args) + else: + output = self.trainer.model(*args) + return output + + def validation_step(self, args): + output = self.training_step(args) + return output + + def test_step(self, args): + output = self.training_step(args) + return output + + def barrier(self, name: str = None): + if torch_distrib.is_initialized(): + torch_distrib.barrier() + + def early_stopping_should_stop(self, pl_module): + stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) + dist.all_reduce(stop, op=dist.reduce_op.SUM) + dist.barrier() + should_stop = stop == self.trainer.world_size + return should_stop + + def broadcast(self, obj, src=0): + return self.dist.broadcast(obj) + + def ddp_train(self, process_idx, model): + """ + Entry point for ddp + + Args: + process_idx: + mp_queue: multiprocessing queue + model: + + Returns: + + """ + # determine which process we are and world size + self.set_world_ranks(process_idx) + + # toggle prog bar + if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None: + self.trainer.progress_bar_callback.disable() + + # set warning rank + rank_zero_only.rank = self.trainer.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + model.trainer = self.trainer + self.init_ddp_connection( + self.trainer.global_rank, + self.trainer.world_size, + self.trainer.is_slurm_managing_tasks + ) + + # call setup after the ddp process has connected + self.trainer.call_setup_hook(model) + + # on world_size=0 let everyone know training is starting + if self.trainer.is_global_zero and not torch.distributed.is_initialized(): + log.info('-' * 100) + log.info(f'distributed_backend={self.trainer.distributed_backend} (TORCH_ELASTIC)') + log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') + log.info('-' * 100) + + # call sync_bn before .cuda(), configure_apex and configure_ddp + if self.trainer.sync_batchnorm: + model = model.configure_sync_batchnorm(model) + + # move the model to the correct device + self.model_to_device(model, process_idx) + + # CHOOSE OPTIMIZER + # allow for lr schedulers as well + self.setup_optimizers(model) + + # set model properties before going into wrapper + self.trainer.model_connector.copy_trainer_model_properties(model) + + # 16-bit + model = self.trainer.precision_connector.connect(model) + + # device ids change depending on the DDP setup + device_ids = self.get_device_ids() + + # allow user to configure ddp + model = model.configure_ddp(model, device_ids) + + # set up training routine + self.trainer.train_loop.setup_training(model) + + # train or test + results = self.train_or_test() + + # clean up memory + torch.cuda.empty_cache() + + return results diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index 14f7c95e39..a6a1cdbfb5 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -29,6 +29,10 @@ class SLURMConnector: self.trainer.num_slurm_tasks = int(os.environ['SLURM_NTASKS']) self.trainer.is_slurm_managing_tasks = self.trainer.num_slurm_tasks == self.trainer.num_requested_gpus + # enable slurm cpu + if self.trainer.num_requested_gpus == 0: + self.trainer.is_slurm_managing_tasks = self.trainer.num_slurm_tasks == self.trainer.num_processes + # in interactive mode we don't manage tasks job_name = os.environ['SLURM_JOB_NAME'] if job_name == 'bash': diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py new file mode 100644 index 0000000000..2cd417fbc3 --- /dev/null +++ b/tests/backends/test_accelerator_connector.py @@ -0,0 +1,218 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import pytest +import os +from tests.base.boring_model import BoringModel +from pytorch_lightning.callbacks import Callback +from pytorch_lightning import accelerators, Trainer +from unittest import mock + + +def test_accelerator_choice_cpu(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.CPUBackend) + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + callbacks=[CB()] + ) + trainer.fit(model) + + +def test_accelerator_choice_ddp_cpu(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSpawnBackend) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp_cpu', + callbacks=[CB()] + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) +@mock.patch('torch.cuda.device_count', return_value=2) +def test_accelerator_choice_ddp(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.DDPBackend) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp', + gpus=1, + callbacks=[CB()] + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) +@mock.patch('torch.cuda.device_count', return_value=2) +def test_accelerator_choice_ddp_spawn(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.DDPSpawnBackend) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp_spawn', + gpus=1, + callbacks=[CB()] + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict(os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "SLURM_LOCALID": "0" +}) +@mock.patch('torch.cuda.device_count', return_value=2) +def test_accelerator_choice_ddp_slurm(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.DDPSLURMBackend) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp', + gpus=2, + callbacks=[CB()] + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict(os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" +}) +@mock.patch('torch.cuda.device_count', return_value=2) +def test_accelerator_choice_ddp2_slurm(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.DDP2Backend) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp2', + gpus=2, + callbacks=[CB()] + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict(os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "WORLD_SIZE": "2", + "LOCAL_RANK": "0", + "NODE_RANK": "0" +}) +@mock.patch('torch.cuda.device_count', return_value=2) +def test_accelerator_choice_ddp_te(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.DDPTorchElasticBackend) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp', + gpus=2, + callbacks=[CB()] + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict(os.environ, { + "WORLD_SIZE": "1", + "LOCAL_RANK": "0", + "NODE_RANK": "0" +}) +@mock.patch('torch.cuda.device_count', return_value=0) +def test_accelerator_choice_ddp_cpu_te(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUTorchElasticBackend) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp_cpu', + num_processes=1, + callbacks=[CB()] + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict(os.environ, { + "SLURM_NTASKS": "1", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" +}) +@mock.patch('torch.cuda.device_count', return_value=0) +def test_accelerator_choice_ddp_cpu_slurm(tmpdir): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSLURMBackend) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + distributed_backend='ddp_cpu', + num_processes=1, + callbacks=[CB()] + ) + + with pytest.raises(SystemExit): + trainer.fit(model)