Fix sync_dist for tpus (#6950)
This commit is contained in:
parent 80c5293514
commit 1b3e4f9fb9

@@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


+- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
+
+
 - Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))

@@ -106,7 +106,7 @@ class Accelerator(object):
         self.precision_plugin.pre_dispatch()

     def post_dispatch(self, trainer: 'pl.Trainer') -> None:
-        """Hook to do something before the training/evaluation/prediction starts."""
+        """Hook to do something after the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()

@@ -21,7 +21,7 @@ import torch
 from torch import Tensor
 from torchmetrics import Metric

-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


 class Result(Dict):

@@ -105,10 +105,11 @@ class Result(Dict):
         # sync across workers when using distributed training
         sync_fn = sync_fn or sync_ddp_if_available

         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            if is_dist_initialized and isinstance(value, torch.Tensor):
+            if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
                 value = value.clone()
             else:
                 value = torch.tensor(value, device=device, dtype=torch.float)

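A minimal CPU-only sketch of the branch above (the helper name and the boolean flags are illustrative, not the library's API): a tensor is cloned before the in-place reduction whenever either `torch.distributed` or a multi-core TPU run is active, while plain numbers are promoted to a float tensor first.

import torch


def _prepare_for_sync(value, is_dist_initialized: bool, is_tpu_distributed: bool, device=None):
    # mirror the condition in Result.log: clone tensors when any distributed
    # backend is active so the reduction does not mutate the caller's tensor
    if (is_dist_initialized or is_tpu_distributed) and isinstance(value, torch.Tensor):
        return value.clone()
    # plain numbers (and tensors outside a distributed run) become float tensors
    return torch.tensor(value, device=device, dtype=torch.float)


assert _prepare_for_sync(3, False, False).dtype == torch.float
assert _prepare_for_sync(torch.tensor([1.0]), False, True).shape == (1,)
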
@@ -15,7 +15,7 @@ import io
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

 import torch
 import torch.multiprocessing as mp

@@ -41,7 +41,6 @@ else:
 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf

-
 if TYPE_CHECKING:
     from torch.nn import Module
     from torch.utils.data import DataLoader

@@ -278,4 +277,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor.unsqueeze(0))
+        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+        return xm.all_gather(tensor)

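For context, a plain-torch sketch (no torch_xla required) of why the `dim() == 0` guard replaces the unconditional `unsqueeze(0)`: only scalar tensors lack the leading dimension that `all_gather` stacks along, while batched tensors already have one.

import torch

scalar = torch.tensor(3.0)   # 0-dim tensor, shape ()
batched = torch.rand(4, 2)   # already has a leading dimension

assert scalar.dim() == 0
assert scalar.unsqueeze(0).shape == (1,)        # the scalar still needs the extra axis

assert batched.unsqueeze(0).shape == (1, 4, 2)  # old behaviour: spurious axis before gathering
assert batched.shape == (4, 2)                  # new behaviour: left untouched
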
@@ -53,12 +53,10 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,
+    _TPU_AVAILABLE,
     _XLA_AVAILABLE,
 )
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: F401
-
-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps

@@ -17,16 +17,14 @@ import os
 import warnings
 from functools import partial, wraps
 from typing import Any, Optional, Union
-from pytorch_lightning.utilities.imports import (
-    _TORCH_GREATER_EQUAL_1_8,
-    _TORCH_GREATER_EQUAL_1_9,
-)

 import torch

 from torch.nn.parallel.distributed import DistributedDataParallel

-log = logging.getLogger(__name__)
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
+
+if _TPU_AVAILABLE:
+    import torch_xla.core.xla_model as xm

 if torch.distributed.is_available():
     from torch.distributed import group, ReduceOp

@@ -40,6 +38,9 @@ else:
         WORLD = None


+log = logging.getLogger(__name__)
+
+
 def rank_zero_only(fn):

     @wraps(fn)

@@ -294,19 +295,13 @@ def register_ddp_comm_hook(
     )
     """
     if not _TORCH_GREATER_EQUAL_1_8:
-        rank_zero_warn(
-            "Not registering DDP comm hook. "
-            "To use communication hooks, please use pytorch>=1.8.0."
-        )
+        rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
         return
     if ddp_comm_hook is None:
         return
     if ddp_comm_wrapper is not None:
         if not _TORCH_GREATER_EQUAL_1_9:
-            rank_zero_warn(
-                "Not applying DDP comm wrapper. "
-                "To use communication wrapper, please use pytorch>=1.9.0."
-            )
+            rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
         else:
             rank_zero_info(
                 f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."

@@ -318,3 +313,9 @@ def register_ddp_comm_hook(
         state=ddp_comm_state,
         hook=ddp_comm_hook,
     )
+
+
+def tpu_distributed() -> bool:
+    if _TPU_AVAILABLE:
+        return xm.xrt_world_size() > 1
+    return False

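A rough standalone equivalent of `tpu_distributed()` above, written defensively for machines without TPUs (this sketch guards the import itself instead of relying on the `_TPU_AVAILABLE` flag):

def tpu_distributed_sketch() -> bool:
    # report a distributed TPU run only when torch_xla is importable and
    # more than one XLA process is participating
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        return False
    return xm.xrt_world_size() > 1
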
@@ -89,3 +89,7 @@ _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supporte
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available('torchvision')
 _XLA_AVAILABLE = _module_available("torch_xla")
+
+from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: E402
+
+_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

@@ -16,6 +16,7 @@ from argparse import ArgumentParser
 from unittest import mock

 import pytest
+import torch
 from torch.utils.data import DataLoader

 import tests.helpers.pipelines as tpipes

@@ -23,6 +24,7 @@ import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TPU_AVAILABLE

@@ -416,3 +418,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_sync_dist():
+    """Test tpu spawn sync dist operation."""
+
+    def test_sync_dist(rank):
+        tensor = torch.tensor([1.0])
+        training_type_plugin = TPUSpawnPlugin()
+
+        res = Result()
+        res.log(
+            "test_tensor",
+            tensor,
+            sync_fn=training_type_plugin.reduce,
+            sync_dist=True,
+            sync_dist_op=torch.distributed.ReduceOp.SUM
+        )
+
+        assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"
+
+    xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')

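For the asserted value: `xmp.spawn(..., nprocs=8)` runs the closure on eight TPU processes, each logging a tensor of `1.0`, and the `ReduceOp.SUM` reduction across them yields `8 * 1.0 == 8`.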