diff --git a/CHANGELOG.md b/CHANGELOG.md
index df7e1bdd81..a0d229887b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
 
 
+- Fixed `sync_dist` for TPUs ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
+
+
 - Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))
 
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index fea4ae725b..972ec275a5 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -106,7 +106,7 @@ class Accelerator(object):
         self.precision_plugin.pre_dispatch()
 
     def post_dispatch(self, trainer: 'pl.Trainer') -> None:
-        """Hook to do something before the training/evaluation/prediction starts."""
+        """Hook to do something after the training/evaluation/prediction finishes."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
 
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index eb0f26cec2..7a193662b5 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -21,7 +21,7 @@ import torch
 from torch import Tensor
 from torchmetrics import Metric
 
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed
 
 
 class Result(Dict):
@@ -105,10 +105,11 @@ class Result(Dict):
 
         # sync across workers when using distributed training
         sync_fn = sync_fn or sync_ddp_if_available
+
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            if is_dist_initialized and isinstance(value, torch.Tensor):
+            if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
                 value = value.clone()
             else:
                 value = torch.tensor(value, device=device, dtype=torch.float)
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index b072a29c7f..5bcfd093ae 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -15,7 +15,7 @@ import io
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.multiprocessing as mp
@@ -41,7 +41,6 @@ else:
 
 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf
-
 if TYPE_CHECKING:
     from torch.nn import Module
     from torch.utils.data import DataLoader
@@ -278,4 +277,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         Return:
             A tensor of shape (world_size, batch, ...)
""" - return xm.all_gather(tensor.unsqueeze(0)) + if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + return xm.all_gather(tensor) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 398e3782be..3c1108b535 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -53,12 +53,10 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401 _TORCH_QUANTIZE_AVAILABLE, _TORCHTEXT_AVAILABLE, _TORCHVISION_AVAILABLE, + _TPU_AVAILABLE, _XLA_AVAILABLE, ) from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 -from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401 - -_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 018d83a93a..a54d00a983 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -17,16 +17,14 @@ import os import warnings from functools import partial, wraps from typing import Any, Optional, Union -from pytorch_lightning.utilities.imports import ( - _TORCH_GREATER_EQUAL_1_8, - _TORCH_GREATER_EQUAL_1_9, -) import torch - from torch.nn.parallel.distributed import DistributedDataParallel -log = logging.getLogger(__name__) +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm if torch.distributed.is_available(): from torch.distributed import group, ReduceOp @@ -40,6 +38,9 @@ else: WORLD = None +log = logging.getLogger(__name__) + + def rank_zero_only(fn): @wraps(fn) @@ -294,19 +295,13 @@ def register_ddp_comm_hook( ) """ if not _TORCH_GREATER_EQUAL_1_8: - rank_zero_warn( - "Not registering DDP comm hook. " - "To use communication hooks, please use pytorch>=1.8.0." - ) + rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.") return if ddp_comm_hook is None: return if ddp_comm_wrapper is not None: if not _TORCH_GREATER_EQUAL_1_9: - rank_zero_warn( - "Not applying DDP comm wrapper. " - "To use communication wrapper, please use pytorch>=1.9.0." - ) + rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.") else: rank_zero_info( f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})." 
@@ -318,3 +313,9 @@ def register_ddp_comm_hook(
         state=ddp_comm_state,
         hook=ddp_comm_hook,
     )
+
+
+def tpu_distributed() -> bool:
+    if _TPU_AVAILABLE:
+        return xm.xrt_world_size() > 1
+    return False
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index c0f8d852a9..740fe73f6c 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -89,3 +89,7 @@ _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supporte
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available('torchvision')
 _XLA_AVAILABLE = _module_available("torch_xla")
+
+from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: E402
+
+_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 6409f2ef4b..e623c48088 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -16,6 +16,7 @@ from argparse import ArgumentParser
 from unittest import mock
 
 import pytest
+import torch
 from torch.utils.data import DataLoader
 
 import tests.helpers.pipelines as tpipes
@@ -23,6 +24,7 @@ import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TPU_AVAILABLE
@@ -416,3 +418,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_sync_dist():
+    """Test tpu spawn sync dist operation """
+
+    def test_sync_dist(rank):
+        tensor = torch.tensor([1.0])
+        training_type_plugin = TPUSpawnPlugin()
+
+        res = Result()
+        res.log(
+            "test_tensor",
+            tensor,
+            sync_fn=training_type_plugin.reduce,
+            sync_dist=True,
+            sync_dist_op=torch.distributed.ReduceOp.SUM
+        )
+
+        assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"
+
+    xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
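
Usage sketch (not part of the patch): with this change, a scalar logged with `sync_dist=True` is reduced across TPU processes instead of keeping only the local value. A minimal sketch of the user-facing path, assuming an 8-core TPU host with `torch_xla` installed; `BoringModel` comes from the repo's test helpers, and the model and metric names are illustrative only:

```python
import torch
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # repo test helper, used here for illustration


class SyncedModel(BoringModel):

    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
        # With the fix, this scalar is reduced across the 8 TPU processes
        # (via the TPUSpawnPlugin reduce path) rather than logged per core.
        self.log("train_loss_synced", out["loss"], sync_dist=True)
        return out


if __name__ == "__main__":
    trainer = Trainer(tpu_cores=8, max_epochs=1, limit_train_batches=2)
    trainer.fit(SyncedModel())
```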
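The 0-dim guard added to `TPUSpawnPlugin.all_gather` can also be exercised directly. A rough sketch, assuming it runs on 8 TPU cores spawned with `xmp.spawn`; the shapes in the comments follow from `xm.all_gather` concatenating along dim 0, and the helper name `check_all_gather` is hypothetical:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

from pytorch_lightning.plugins import TPUSpawnPlugin


def check_all_gather(rank):
    plugin = TPUSpawnPlugin()
    device = xm.xla_device()

    scalar = torch.tensor(1.0, device=device)  # 0-dim: unsqueezed inside all_gather
    assert plugin.all_gather(scalar).shape == (8,)

    vector = torch.ones(4, device=device)  # already 1-dim: passed through as-is
    assert plugin.all_gather(vector).shape == (32,)  # concatenated along dim 0


xmp.spawn(check_all_gather, nprocs=8, start_method='fork')
```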