diff --git a/docs/source/index.rst b/docs/source/index.rst index 32f4ab3dbb..ec85f7e043 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -141,3 +141,4 @@ Indices and tables api/pytorch_lightning.utilities api/pytorch_lightning.tuner api/pytorch_lightning.plugins + api/pytorch_lightning.distributed diff --git a/pytorch_lightning/accelerators/base_backend.py b/pytorch_lightning/accelerators/base_backend.py index 0afedf14ab..60ea76aaa7 100644 --- a/pytorch_lightning/accelerators/base_backend.py +++ b/pytorch_lightning/accelerators/base_backend.py @@ -7,6 +7,7 @@ import torch from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.parsing import AttributeDict try: from apex import amp @@ -21,6 +22,7 @@ class Accelerator(object): def __init__(self, trainer): self.trainer = trainer + self.dist = AttributeDict(rank=0, device=None) def setup(self, model): pass @@ -31,6 +33,9 @@ class Accelerator(object): def barrier(self, name: str = None): pass + def broadcast(self, obj, src=0): + return obj + def train_or_test(self): if self.trainer.testing: results = self.trainer.run_test() diff --git a/pytorch_lightning/accelerators/ddp_base_backend.py b/pytorch_lightning/accelerators/ddp_base_backend.py index 526b06d59d..35dc89abe6 100644 --- a/pytorch_lightning/accelerators/ddp_base_backend.py +++ b/pytorch_lightning/accelerators/ddp_base_backend.py @@ -24,6 +24,7 @@ from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.distributed.dist import LightningDistributed try: from hydra.core.hydra_config import HydraConfig @@ -38,6 +39,7 @@ class DDPBase(Accelerator): def __init__(self, trainer): super().__init__(trainer) + self.dist = LightningDistributed() def training_step(self, args): if self.trainer.amp_backend == AMPType.NATIVE: @@ -177,6 +179,9 @@ class DDPBase(Accelerator): if self.trainer.global_rank == 0: return results + def broadcast(self, obj, src=0): + return self.dist.broadcast(obj) + def set_world_ranks(self, process_idx): raise NotImplementedError('to create a ddp backend, please implement set_world_ranks') diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py index 58577be618..2f0c6d29c7 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py @@ -25,6 +25,7 @@ from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.distributed import find_free_network_port +from pytorch_lightning.distributed.dist import LightningDistributed try: from hydra.core.hydra_config import HydraConfig @@ -41,6 +42,7 @@ class DDPCPUSpawnBackend(Accelerator): super().__init__(trainer) self.mp_queue = None self.nprocs = nprocs + self.dist = LightningDistributed() def setup(self, model): os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) @@ -174,6 +176,9 @@ class DDPCPUSpawnBackend(Accelerator): def barrier(self, name: str = None): torch_distrib.barrier() + def broadcast(self, obj, src=0): + return self.dist.broadcast(obj) + def early_stopping_should_stop(self, pl_module): stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) dist.all_reduce(stop, op=dist.reduce_op.SUM) diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py index 87d00bcd8b..0bf7e18bc2 100644 --- a/pytorch_lightning/accelerators/dp_backend.py +++ b/pytorch_lightning/accelerators/dp_backend.py @@ -16,7 +16,7 @@ import torch from torch import optim from pytorch_lightning.accelerators.base_backend import Accelerator -from pytorch_lightning.core import LightningModule +from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.core.step_result import Result from pytorch_lightning.overrides.data_parallel import LightningDataParallel from pytorch_lightning.utilities import AMPType @@ -28,6 +28,7 @@ class DataParallelBackend(Accelerator): def __init__(self, trainer): super().__init__(trainer) self.model_autocast_original_forward = None + self.dist = LightningDistributed() def setup(self, model): # call setup after the ddp process has connected diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py index d3c6c59160..ea1d57ccea 100644 --- a/pytorch_lightning/accelerators/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.utilities import AMPType +from pytorch_lightning.distributed.dist import LightningDistributed class GPUBackend(Accelerator): @@ -23,6 +24,7 @@ class GPUBackend(Accelerator): def __init__(self, trainer): super().__init__(trainer) + self.dist = LightningDistributed() def setup(self, model): diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py index cfdf80fa2b..2fcf75c215 100644 --- a/pytorch_lightning/accelerators/horovod_backend.py +++ b/pytorch_lightning/accelerators/horovod_backend.py @@ -158,3 +158,7 @@ class HorovodBackend(Accelerator): def barrier(self, name: str = None): hvd.join() + + def broadcast(self, obj, src=0): + obj = hvd.broadcast_object(obj, src) + return obj diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d29b7e4ce8..4357163f46 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -366,7 +366,6 @@ class ModelCheckpoint(Callback): ckpt_name = f"{filename}.ckpt" return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name - @rank_zero_only def __resolve_ckpt_dir(self, trainer, pl_module): """ Determines model checkpoint save directory at runtime. References attributes from the @@ -398,8 +397,11 @@ class ModelCheckpoint(Callback): if isinstance(trainer.logger.version, str) else f"version_{trainer.logger.version}" ) + + version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name)) + ckpt_path = os.path.join( - save_dir, trainer.logger.name, version, "checkpoints" + save_dir, name, version, "checkpoints" ) else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") diff --git a/pytorch_lightning/distributed/__init__.py b/pytorch_lightning/distributed/__init__.py new file mode 100644 index 0000000000..15540f7a7d --- /dev/null +++ b/pytorch_lightning/distributed/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.distributed.dist import LightningDistributed diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py new file mode 100644 index 0000000000..a3f0378f0a --- /dev/null +++ b/pytorch_lightning/distributed/dist.py @@ -0,0 +1,36 @@ +import io +import torch +from typing import Any +from torch import distributed as torch_distrib + + +class LightningDistributed: + + def __init__(self, rank=None, device=None): + self.rank = rank + self.device = device + + def broadcast(self, obj: Any): + if self.rank == 0: + self._emit(obj) + else: + obj = self._receive() + return obj + + def _emit(self, obj): + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.tensor([len(data)]).long().to(self.device) + length_tensor = torch_distrib.broadcast(length_tensor, src=0) + data_tensor = torch.ByteTensor(data).to(self.device) + data_tensor = torch_distrib.broadcast(data_tensor, src=0) + + def _receive(self): + length_tensor = torch.tensor([0]).long().to(self.device) + torch_distrib.broadcast(length_tensor, src=0) + data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) + torch_distrib.broadcast(data_tensor, src=0) + buffer = io.BytesIO(data_tensor.cpu().numpy()) + obj = torch.load(buffer) + return obj diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index ff8aab3743..26572feb5e 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -69,7 +69,7 @@ class TrainerLoggingMixin(ABC): m = inspect.cleandoc( f"""The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0 Please use self.log(...) inside the lightningModule instead. - + # log on a step or aggregate epoch metric to the logger and/or progress bar # (inside LightningModule) self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 428e081a5c..e670c01f04 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -108,6 +108,9 @@ class TrainLoop: if self.trainer.data_parallel: ref_model = model.module + self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank + self.trainer.accelerator_backend.dist.device = ref_model.device + # give model convenience properties ref_model.trainer = self.trainer diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 4fdd03e1e6..fd39490318 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -437,7 +437,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): trainer.fit(model) path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt") - path_last = str(tmpdir / f"last.ckpt") + path_last = str(tmpdir / "last.ckpt") assert path_last == model_checkpoint.last_model_path ckpt_last_epoch = torch.load(path_last_epoch) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 4e4476955a..5325ca828e 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -235,6 +235,7 @@ def test_dm_checkpoint_save(tmpdir): assert dm.__class__.__name__ in checkpoint assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ + def test_test_loop_only(tmpdir): reset_seed()