add ddp sync for logging in result step (#2822)

* add ddp sync for logging in result step

* pep8

* pep8

* make ddp tests run also on cpu (except windows)

* create class instance in ddp test

* revert automated formatting

* pep8
Justus Schock 2020-08-06 02:42:09 +02:00 committed by GitHub
parent b507c42c47
commit fe29c53ab5
2 changed files with 66 additions and 4 deletions
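
With this change, Result.log and the TrainResult / EvalResult wrappers accept sync_ddp, sync_ddp_op and sync_ddp_group arguments, so a logged value can be reduced across DDP processes before it is stored. A minimal usage sketch (the surrounding values are illustrative; only the log calls mirror the diff below):

import torch
import torch.distributed as dist
from pytorch_lightning.core.step_result import TrainResult

# stand-in for a loss computed in training_step
loss = torch.tensor(0.25, requires_grad=True)
result = TrainResult(minimize=loss)

# average 'train_loss' across all DDP processes before it is stored
result.log('train_loss', loss, sync_ddp=True, sync_ddp_op='mean')

# or reduce with an explicit torch.distributed op and an (optional) process group
result.log('num_samples', torch.tensor(32.0),
           sync_ddp=True, sync_ddp_op=dist.ReduceOp.SUM, sync_ddp_group=None)

Outside an initialized process group, the sync step should be a no-op, so the same calls run unchanged on a single process.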


@@ -1,7 +1,9 @@
import numbers
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
from torch import Tensor
import torch
from copy import copy
from pytorch_lightning.metrics.converters import _sync_ddp_if_available
class Result(Dict):
@@ -89,11 +91,18 @@ class Result(Dict):
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
value = value.detach()
# sync across ddp
if sync_ddp and isinstance(value, (torch.Tensor, numbers.Number)):
value = _sync_ddp_if_available(value, group=sync_ddp_group, reduce_op=sync_ddp_op)
if 'meta' not in self:
self.__setitem__('meta', {})
@@ -338,6 +347,9 @@ class TrainResult(Result):
on_epoch: bool = False,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a key, value
@@ -369,7 +381,8 @@ class TrainResult(Result):
reduce_fx: Torch.mean by default
enable_graph: if True, will not auto detach the graph
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
def log_dict(
self,
@@ -380,6 +393,9 @@ class TrainResult(Result):
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a dictionary of values at once
@@ -399,7 +415,8 @@ class TrainResult(Result):
enable_graph:
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
class EvalResult(Result):
@@ -446,6 +463,9 @@ class EvalResult(Result):
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a key, value
@@ -476,7 +496,8 @@ class EvalResult(Result):
reduce_fx: Torch.mean by default
enable_graph: if True, will not auto detach the graph
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
def log_dict(
self,
@@ -487,6 +508,9 @@ class EvalResult(Result):
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a dictionary of values at once
@@ -506,7 +530,8 @@ class EvalResult(Result):
enable_graph:
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
def get_callback_metrics(self) -> dict:
result = {

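The reduction itself is delegated to _sync_ddp_if_available, imported above from pytorch_lightning.metrics.converters. A rough sketch of the assumed behaviour, not the library source: all-reduce the value across processes when torch.distributed is available and initialized, otherwise return it unchanged.

import numbers

import torch
import torch.distributed as dist


def _sync_ddp_if_available_sketch(value, group=None, reduce_op='mean'):
    # outside an initialized distributed environment there is nothing to sync
    if not (dist.is_available() and dist.is_initialized()):
        return value

    if isinstance(value, numbers.Number):
        value = torch.tensor(float(value))
    if group is None:
        group = dist.group.WORLD

    # map string ops onto torch.distributed reduce ops; 'mean' is emulated
    # as a sum followed by division by the world size
    divide_by_world_size = False
    if reduce_op in ('mean', 'avg'):
        op = dist.ReduceOp.SUM
        divide_by_world_size = True
    elif reduce_op in (None, 'sum'):
        op = dist.ReduceOp.SUM
    else:
        op = reduce_op  # assume a torch.distributed.ReduceOp was passed through

    dist.all_reduce(value, op=op, group=group, async_op=False)
    if divide_by_world_size:
        value = value / dist.get_world_size(group)
    return value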

@@ -0,0 +1,37 @@
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
import tests.base.develop_utils as tutils
import sys
def _setup_ddp(rank, worldsize):
    import os

    os.environ["MASTER_ADDR"] = "localhost"
    # MASTER_PORT is expected to be set by tutils.set_random_master_port()
    # in the parent process and inherited by the spawned workers

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize, result_cls: Result):
    _setup_ddp(rank, worldsize)
    tensor = torch.tensor([1.0])

    res = result_cls()
    # summing the logged tensor across both processes should yield the world size
    res.log("test_tensor", tensor, sync_ddp=True, sync_ddp_op=torch.distributed.ReduceOp.SUM)

    assert res["test_tensor"].item() == dist.get_world_size(), \
        "Result-Log does not work properly with DDP and Tensors"


@pytest.mark.parametrize("result_cls", [Result, TrainResult, EvalResult])
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_result_reduce_ddp(result_cls):
    """Make sure result logging works with DDP"""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize, result_cls), nprocs=worldsize)