add ddp sync for logging in result step (#2822)
* add ddp sync for logging in result step
* pep8
* pep8
* make ddp tests run also on cpu (except windows)
* create class instance in ddp test
* revert automated formatting
* pep8
parent b507c42c47
commit fe29c53ab5
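For orientation, here is a minimal, hypothetical sketch of how the new sync flags are meant to be used from a training_step, based only on the signatures added in the diff below (the module and the compute_loss helper are made-up names, not part of this commit):

import torch
import pytorch_lightning as pl
from pytorch_lightning.core.step_result import TrainResult


class LitModel(pl.LightningModule):
    # layers, forward, configure_optimizers omitted for brevity

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        result = TrainResult(minimize=loss)
        # with sync_ddp=True the logged value is reduced across all DDP
        # processes before it reaches the logger / progress bar
        result.log('train_loss', loss, sync_ddp=True, sync_ddp_op='mean')
        return result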
@@ -1,7 +1,9 @@
+import numbers
 from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
 from torch import Tensor
 import torch
 from copy import copy
+from pytorch_lightning.metrics.converters import _sync_ddp_if_available


 class Result(Dict):
@@ -89,11 +91,18 @@ class Result(Dict):
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
             value = value.detach()

+        # sync across ddp
+        if sync_ddp and isinstance(value, (torch.Tensor, numbers.Number)):
+            value = _sync_ddp_if_available(value, group=sync_ddp_group, reduce_op=sync_ddp_op)
+
         if 'meta' not in self:
             self.__setitem__('meta', {})

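The actual reduction is delegated to _sync_ddp_if_available from pytorch_lightning.metrics.converters, which the first hunk imports. Roughly, that helper is a no-op unless a process group is initialized and otherwise all-reduces the value. The sketch below is an approximation under stated assumptions (the name sync_if_available and the mean-as-sum-then-divide handling are illustrative, not the library's exact code):

import numbers
from typing import Any, Optional, Union

import torch
import torch.distributed as dist


def sync_if_available(value: Union[torch.Tensor, numbers.Number],
                      group: Optional[Any] = None,
                      reduce_op: Union[Any, str] = 'mean') -> Union[torch.Tensor, numbers.Number]:
    """Reduce ``value`` across all processes; a no-op when no process group is up."""
    if not (dist.is_available() and dist.is_initialized()):
        return value

    # plain Python numbers are wrapped so they can take part in the collective
    if isinstance(value, numbers.Number):
        value = torch.tensor(value, dtype=torch.float)

    group = group if group is not None else dist.group.WORLD

    # assume 'mean'/'avg' is implemented as an all-reduce SUM followed by a division
    divide_by_world_size = reduce_op in ('avg', 'mean')
    op = dist.ReduceOp.SUM if divide_by_world_size else reduce_op

    dist.all_reduce(value, op=op, group=group)
    if divide_by_world_size:
        value = value / dist.get_world_size(group)
    return value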
@@ -338,6 +347,9 @@ class TrainResult(Result):
         on_epoch: bool = False,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         """
         Log a key, value
@@ -369,7 +381,8 @@
             reduce_fx: Torch.mean by default
             enable_graph: if True, will not auto detach the graph
         """
-        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
+        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
+                    sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

     def log_dict(
         self,
@@ -380,6 +393,9 @@ class TrainResult(Result):
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         """
         Log a dictonary of values at once
@@ -399,7 +415,8 @@ class TrainResult(Result):
             enable_graph:
         """
         for k, v in dictionary.items():
-            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
+            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
+                     sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)


 class EvalResult(Result):
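The log_dict variants simply forward the same three keyword arguments to log for every entry, so a whole metrics dict can be synced at once. A small, hypothetical usage sketch (keys and values are made up; with no process group initialized the sync is a no-op):

import torch
from pytorch_lightning.core.step_result import TrainResult

metrics = {'train_loss': torch.tensor(0.25), 'train_acc': torch.tensor(0.90)}

result = TrainResult()
result.log_dict(
    metrics,
    on_step=True,
    sync_ddp=True,       # each value is reduced across DDP processes
    sync_ddp_op='mean',  # average across processes (the default op)
)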
@@ -446,6 +463,9 @@ class EvalResult(Result):
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         """
         Log a key, value
@@ -476,7 +496,8 @@ class EvalResult(Result):
             reduce_fx: Torch.mean by default
             enable_graph: if True, will not auto detach the graph :
         """
-        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
+        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
+                    sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

     def log_dict(
         self,
@@ -487,6 +508,9 @@ class EvalResult(Result):
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         """
         Log a dictonary of values at once
@@ -506,7 +530,8 @@ class EvalResult(Result):
             enable_graph:
         """
         for k, v in dictionary.items():
-            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
+            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
+                     sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

     def get_callback_metrics(self) -> dict:
         result = {
@@ -0,0 +1,37 @@
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
+import tests.base.develop_utils as tutils
+import sys
+
+
+def _setup_ddp(rank, worldsize):
+    import os
+
+    os.environ["MASTER_ADDR"] = "localhost"
+
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=worldsize)
+
+
+def _ddp_test_fn(rank, worldsize, result_cls: Result):
+    _setup_ddp(rank, worldsize)
+    tensor = torch.tensor([1.0])
+
+    res = result_cls()
+    res.log("test_tensor", tensor, sync_ddp=True, sync_ddp_op=torch.distributed.ReduceOp.SUM)
+
+    assert res["test_tensor"].item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
+
+
+@pytest.mark.parametrize("result_cls", [Result, TrainResult, EvalResult])
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_result_reduce_ddp(result_cls):
+    """Make sure result logging works with DDP"""
+    tutils.reset_seed()
+    tutils.set_random_master_port()
+
+    worldsize = 2
+    mp.spawn(_ddp_test_fn, args=(worldsize, result_cls), nprocs=worldsize)
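Each of the two spawned processes logs tensor([1.0]) with ReduceOp.SUM, so after the all-reduce the stored value equals the world size (2), which is what the assertion checks. One non-obvious detail: _setup_ddp only sets MASTER_ADDR, while the rendezvous port is exported by tutils.set_random_master_port() in the parent process before mp.spawn, so both workers inherit it. A rough sketch of what such a helper is assumed to do (the actual implementation lives in tests.base.develop_utils and may differ):

import os
import random


def set_random_master_port() -> None:
    # export a quasi-random port so concurrently running test sessions are
    # unlikely to collide on the same gloo rendezvous address
    os.environ["MASTER_PORT"] = str(random.randint(10000, 19000))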