Fix sync_dist for tpus (#6950)

Kaushik B 2021-04-13 14:17:15 +05:30 committed by GitHub
parent 80c5293514
commit 1b3e4f9fb9
8 changed files with 56 additions and 23 deletions

View File

@@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))

View File

@@ -106,7 +106,7 @@ class Accelerator(object):
self.precision_plugin.pre_dispatch()
def post_dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()

View File

@@ -21,7 +21,7 @@ import torch
from torch import Tensor
from torchmetrics import Metric
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed
class Result(Dict):
@@ -105,10 +105,11 @@ class Result(Dict):
# sync across workers when using distributed training
sync_fn = sync_fn or sync_ddp_if_available
if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
# TODO: Find a way to make the reduction only once, so we don't need to clone.
if is_dist_initialized and isinstance(value, torch.Tensor):
if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor):
value = value.clone()
else:
value = torch.tensor(value, device=device, dtype=torch.float)
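For context, a minimal sketch of how this code path is reached from user code, assuming a standard LightningModule; `compute_loss` below is a hypothetical helper, not part of the diff. Logging with `sync_dist=True` routes the value through `Result.log`, which now clones tensors whenever either torch.distributed or a multi-core TPU run is active before handing them to `sync_fn`.

import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper, for illustration only
        # sync_dist=True asks Result.log to reduce the value across workers;
        # after this fix the TPU case is detected via tpu_distributed as well.
        self.log("train_loss", loss, sync_dist=True)
        return loss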

View File

@ -15,7 +15,7 @@ import io
import os
import re
import time
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
import torch
import torch.multiprocessing as mp
@@ -41,7 +41,6 @@ else:
if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf
if TYPE_CHECKING:
from torch.nn import Module
from torch.utils.data import DataLoader
@@ -278,4 +277,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
Return:
A tensor of shape (world_size, batch, ...)
"""
return xm.all_gather(tensor.unsqueeze(0))
if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
tensor = tensor.unsqueeze(0)
return xm.all_gather(tensor)
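A small CPU-only sketch of the shape handling introduced above; `_ensure_gatherable` is a hypothetical name. `xm.all_gather` concatenates along an existing dimension, so a 0-dim scalar (for example a logged loss) is first promoted to shape (1,), while tensors that already have a batch dimension are presumably passed through unchanged to avoid inserting an extra leading dimension.

import torch

def _ensure_gatherable(tensor: torch.Tensor) -> torch.Tensor:
    # Scalars become 1-D so the gathered result has shape (world_size,);
    # higher-dimensional tensors keep their shape and are gathered along dim 0.
    if tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    return tensor

assert _ensure_gatherable(torch.tensor(1.0)).shape == (1,)
assert _ensure_gatherable(torch.ones(4, 3)).shape == (4, 3)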

View File

@@ -53,12 +53,10 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
_TORCH_QUANTIZE_AVAILABLE,
_TORCHTEXT_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TPU_AVAILABLE,
_XLA_AVAILABLE,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401
_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps

View File

@@ -17,16 +17,14 @@ import os
import warnings
from functools import partial, wraps
from typing import Any, Optional, Union
from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
)
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
log = logging.getLogger(__name__)
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
if torch.distributed.is_available():
from torch.distributed import group, ReduceOp
@@ -40,6 +38,9 @@ else:
WORLD = None
log = logging.getLogger(__name__)
def rank_zero_only(fn):
@wraps(fn)
@@ -294,19 +295,13 @@ def register_ddp_comm_hook(
)
"""
if not _TORCH_GREATER_EQUAL_1_8:
rank_zero_warn(
"Not registering DDP comm hook. "
"To use communication hooks, please use pytorch>=1.8.0."
)
rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
return
if ddp_comm_hook is None:
return
if ddp_comm_wrapper is not None:
if not _TORCH_GREATER_EQUAL_1_9:
rank_zero_warn(
"Not applying DDP comm wrapper. "
"To use communication wrapper, please use pytorch>=1.9.0."
)
rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
else:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
@@ -318,3 +313,9 @@ def register_ddp_comm_hook(
state=ddp_comm_state,
hook=ddp_comm_hook,
)
def tpu_distributed() -> bool:
if _TPU_AVAILABLE:
return xm.xrt_world_size() > 1
return False
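As a usage sketch, the new helper can be combined with the usual torch.distributed check; `needs_cross_device_sync` is a hypothetical wrapper that mirrors the intent of the updated condition in `Result.log` above.

import torch

from pytorch_lightning.utilities.distributed import tpu_distributed

def needs_cross_device_sync() -> bool:
    # True when a torch.distributed job is initialized, or when more than
    # one TPU core participates in the run (xm.xrt_world_size() > 1).
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    return dist_ready or tpu_distributed()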

View File

@@ -89,3 +89,7 @@ _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supporte
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")
from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402
_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
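Moving `_TPU_AVAILABLE` into `imports.py` (it was previously defined in `pytorch_lightning/utilities/__init__.py`, see the hunk above) presumably lets low-level modules import it without going through the package `__init__`; the late `XLADeviceUtils` import with `# noqa: E402` likely sidesteps a circular import, since `xla_device` itself depends on `_XLA_AVAILABLE` defined earlier in this file. A minimal usage sketch, mirroring the guarded import in `distributed.py`:

from pytorch_lightning.utilities.imports import _TPU_AVAILABLE

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm  # only resolvable on an XLA/TPU setup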

View File

@@ -16,6 +16,7 @@ from argparse import ArgumentParser
from unittest import mock
import pytest
import torch
from torch.utils.data import DataLoader
import tests.helpers.pipelines as tpipes
@@ -23,6 +24,7 @@ import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE
@@ -416,3 +418,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_sync_dist():
"""Test tpu spawn sync dist operation """
def test_sync_dist(rank):
tensor = torch.tensor([1.0])
training_type_plugin = TPUSpawnPlugin()
res = Result()
res.log(
"test_tensor",
tensor,
sync_fn=training_type_plugin.reduce,
sync_dist=True,
sync_dist_op=torch.distributed.ReduceOp.SUM
)
assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"
xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
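For readers unfamiliar with the harness: `xmp` presumably refers to `torch_xla.distributed.xla_multiprocessing`, imported near the top of the test module outside this hunk. Each of the eight spawned processes logs a tensor of 1.0 and `ReduceOp.SUM` adds them, hence the expected value of 8. A minimal sketch of the spawn pattern, assuming torch_xla is installed and eight TPU cores are available:

import torch_xla.distributed.xla_multiprocessing as xmp

def _worker(rank):
    # Each process receives its index; a real test would run assertions here.
    print(f"running on TPU core {rank}")

if __name__ == "__main__":
    xmp.spawn(_worker, nprocs=8, start_method='fork')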