diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py index 4f7124c07e..1282da0568 100644 --- a/pytorch_lightning/accelerators/ddp2_backend.py +++ b/pytorch_lightning/accelerators/ddp2_backend.py @@ -13,7 +13,6 @@ # limitations under the License import os -import re import torch @@ -22,11 +21,7 @@ from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.core.step_result import Result -from pytorch_lightning.accelerators.base_backend import Accelerator -import torch.distributed as torch_distrib -import torch.distributed as dist -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.distributed import rank_zero_warn +from pytorch_lightning.accelerators.ddp_base_backend import DDPBase try: from hydra.utils import to_absolute_path, get_original_cwd @@ -42,7 +37,7 @@ except ImportError: amp = None -class DDP2Backend(Accelerator): +class DDP2Backend(DDPBase): def __init__(self, trainer): super().__init__(trainer) @@ -170,22 +165,6 @@ class DDP2Backend(Accelerator): # clean up memory torch.cuda.empty_cache() - def training_step(self, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def validation_step(self, args): - output = self.training_step(args) - return output - - def test_step(self, args): - output = self.training_step(args) - return output - def training_step_end(self, output): if isinstance(output, Result): output.dp_reduce() @@ -200,35 +179,3 @@ class DDP2Backend(Accelerator): if isinstance(output, Result): output.dp_reduce() return output - - def barrier(self, name: str = None): - torch_distrib.barrier() - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - dist.all_reduce(stop, op=dist.reduce_op.SUM) - dist.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']: - return - - # track the best model path - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - atomic_save(model.state_dict(), last_path) - mp_queue.put(last_path) diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py index a675fa4c9f..4b63b69c6b 100644 --- a/pytorch_lightning/accelerators/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -13,7 +13,6 @@ # limitations under the License import os -import re import subprocess import sys from os.path import abspath @@ -25,12 +24,8 @@ import torch from pytorch_lightning import _logger as log from pytorch_lightning.utilities import AMPType -from pytorch_lightning.accelerators.base_backend import Accelerator -import torch.distributed as torch_distrib -import torch.distributed as dist from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.distributed import rank_zero_warn +from pytorch_lightning.accelerators.ddp_base_backend import DDPBase try: from hydra.utils import to_absolute_path, get_original_cwd @@ -46,7 +41,7 @@ except ImportError: amp = None -class DDPBackend(Accelerator): +class DDPBackend(DDPBase): def __init__(self, trainer, mode: str = 'ddp'): super().__init__(trainer) @@ -257,57 +252,9 @@ class DDPBackend(Accelerator): if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']: return results - def training_step(self, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def validation_step(self, args): - output = self.training_step(args) - return output - - def test_step(self, args): - output = self.training_step(args) - return output - def _check_can_spawn_children(self): if self._has_spawned_children: raise RuntimeError( "You tried to run `.fit` or `.test` multiple times in the same script." " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." ) - - def barrier(self, name: str = None): - torch_distrib.barrier() - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - dist.all_reduce(stop, op=dist.reduce_op.SUM) - dist.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']: - return - - # track the best model path - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - atomic_save(model.state_dict(), last_path) - mp_queue.put(last_path) diff --git a/pytorch_lightning/accelerators/ddp_base_backend.py b/pytorch_lightning/accelerators/ddp_base_backend.py new file mode 100644 index 0000000000..ed00a722cc --- /dev/null +++ b/pytorch_lightning/accelerators/ddp_base_backend.py @@ -0,0 +1,90 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import re +import torch + +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.accelerators.base_backend import Accelerator +import torch.distributed as torch_distrib +import torch.distributed as dist +from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities.distributed import rank_zero_warn + +try: + from hydra.utils import to_absolute_path, get_original_cwd + from hydra.core.hydra_config import HydraConfig +except ImportError: + HYDRA_AVAILABLE = False +else: + HYDRA_AVAILABLE = True + +try: + from apex import amp +except ImportError: + amp = None + + +class DDPBase(Accelerator): + + def __init__(self, trainer): + super().__init__(trainer) + + def training_step(self, args): + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.trainer.model(*args) + else: + output = self.trainer.model(*args) + return output + + def validation_step(self, args): + output = self.training_step(args) + return output + + def test_step(self, args): + output = self.training_step(args) + return output + + def barrier(self, name: str = None): + torch_distrib.barrier() + + def early_stopping_should_stop(self, pl_module): + stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) + dist.all_reduce(stop, op=dist.reduce_op.SUM) + dist.barrier() + should_stop = stop == self.trainer.world_size + return should_stop + + def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): + if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']: + return + + # track the best model path + best_model_path = None + if self.trainer.checkpoint_callback is not None: + best_model_path = self.trainer.checkpoint_callback.best_model_path + + if self.trainer.global_rank == 0 and mp_queue is not None: + rank_zero_warn('cleaning up ddp environment...') + # todo, pass complete checkpoint as state dictionary + mp_queue.put(best_model_path) + mp_queue.put(results) + + # save the last weights + last_path = None + if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) + atomic_save(model.state_dict(), last_path) + mp_queue.put(last_path) diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py index 82375aa935..27c02ca46c 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License import os -import re import torch import torch.multiprocessing as mp @@ -20,11 +19,7 @@ import torch.multiprocessing as mp from pytorch_lightning import _logger as log from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port -from pytorch_lightning.accelerators.base_backend import Accelerator -import torch.distributed as torch_distrib -import torch.distributed as dist -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.distributed import rank_zero_warn +from pytorch_lightning.accelerators.ddp_base_backend import DDPBase try: from apex import amp @@ -32,7 +27,7 @@ except ImportError: amp = None -class DDPSpawnBackend(Accelerator): +class DDPSpawnBackend(DDPBase): def __init__(self, trainer, nprocs): super().__init__(trainer) @@ -180,51 +175,3 @@ class DDPSpawnBackend(Accelerator): # clean up memory torch.cuda.empty_cache() - - def training_step(self, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def validation_step(self, args): - output = self.training_step(args) - return output - - def test_step(self, args): - output = self.training_step(args) - return output - - def barrier(self, name: str = None): - torch_distrib.barrier() - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - dist.all_reduce(stop, op=dist.reduce_op.SUM) - dist.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']: - return - - # track the best model path - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - atomic_save(model.state_dict(), last_path) - mp_queue.put(last_path) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6a0b2e2602..957f6f68f9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -876,7 +876,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod global_rank: The global process idx. world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus). is_slurm_managing_tasks: is cluster managed by SLURM. - """ if is_slurm_managing_tasks: self._init_slurm_connection()