From fe29c53ab5eb16758ccc448716e1c365da5c1beb Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Thu, 6 Aug 2020 02:42:09 +0200
Subject: [PATCH] add ddp sync for logging in result step (#2822)

* add ddp sync for logging in result step

* pep8

* pep8

* make ddp tests run also on cpu (except windowws)

* create class instance in ddp test

* revert automated formatting

* pep8
---
 pytorch_lightning/core/step_result.py | 33 +++++++++++++++++++++---
 tests/core/test_results.py            | 37 +++++++++++++++++++++++++++
 2 files changed, 66 insertions(+), 4 deletions(-)
 create mode 100644 tests/core/test_results.py

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 253ccedabc..172930fd4a 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -1,7 +1,9 @@
+import numbers
 from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
 from torch import Tensor
 import torch
 from copy import copy
+from pytorch_lightning.metrics.converters import _sync_ddp_if_available
 
 
 class Result(Dict):
@@ -89,11 +91,18 @@
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
             value = value.detach()
 
+        # sync across ddp
+        if sync_ddp and isinstance(value, (torch.Tensor, numbers.Number)):
+            value = _sync_ddp_if_available(value, group=sync_ddp_group, reduce_op=sync_ddp_op)
+
         if 'meta' not in self:
             self.__setitem__('meta', {})
 
@@ -338,6 +347,9 @@
         on_epoch: bool = False,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         """
         Log a key, value
@@ -369,7 +381,8 @@
             reduce_fx: Torch.mean by default
             enable_graph: if True, will not auto detach the graph
         """
-        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
+        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
+                    sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
 
     def log_dict(
         self,
@@ -380,6 +393,9 @@
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         """
         Log a dictonary of values at once
@@ -399,7 +415,8 @@
             enable_graph:
         """
         for k, v in dictionary.items():
-            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
+            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
+                     sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
 
 
 class EvalResult(Result):
@@ -446,6 +463,9 @@
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         """
         Log a key, value
@@ -476,7 +496,8 @@
             reduce_fx: Torch.mean by default
             enable_graph: if True, will not auto detach the graph :
         """
-        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
+        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
+                    sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
 
     def log_dict(
         self,
@@ -487,6 +508,9 @@
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
+        sync_ddp: bool = False,
+        sync_ddp_op: Union[Any, str] = 'mean',
+        sync_ddp_group: Optional[Any] = None
     ):
         """
         Log a dictonary of values at once
@@ -506,7 +530,8 @@
             enable_graph:
         """
         for k, v in dictionary.items():
-            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
+            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
+                     sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)
 
     def get_callback_metrics(self) -> dict:
         result = {
diff --git a/tests/core/test_results.py b/tests/core/test_results.py
new file mode 100644
index 0000000000..743a6d8915
--- /dev/null
+++ b/tests/core/test_results.py
@@ -0,0 +1,37 @@
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
+import tests.base.develop_utils as tutils
+import sys
+
+
+def _setup_ddp(rank, worldsize):
+    import os
+
+    os.environ["MASTER_ADDR"] = "localhost"
+
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=worldsize)
+
+
+def _ddp_test_fn(rank, worldsize, result_cls: Result):
+    _setup_ddp(rank, worldsize)
+    tensor = torch.tensor([1.0])
+
+    res = result_cls()
+    res.log("test_tensor", tensor, sync_ddp=True, sync_ddp_op=torch.distributed.ReduceOp.SUM)
+
+    assert res["test_tensor"].item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
+
+
+@pytest.mark.parametrize("result_cls", [Result, TrainResult, EvalResult])
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_result_reduce_ddp(result_cls):
+    """Make sure result logging works with DDP"""
+    tutils.reset_seed()
+    tutils.set_random_master_port()
+
+    worldsize = 2
+    mp.spawn(_ddp_test_fn, args=(worldsize, result_cls), nprocs=worldsize)
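
Usage sketch (not part of the patch): with this change, a value logged through a Result object can be all-reduced across DDP processes at logging time via the new sync_ddp, sync_ddp_op and sync_ddp_group arguments. A minimal example assuming a typical LightningModule trained with the DDP backend (so the process group is already initialized by the trainer); compute_loss is a hypothetical helper used only for illustration:

    import pytorch_lightning as pl
    from pytorch_lightning.core.step_result import TrainResult

    class LitModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper defined elsewhere
            result = TrainResult(minimize=loss)
            # reduce the logged value across all DDP processes before it is logged;
            # sync_ddp_op defaults to 'mean', and a specific process group can be
            # passed via sync_ddp_group if only a subset of ranks should participate
            result.log('train_loss', loss, sync_ddp=True, sync_ddp_op='mean')
            return result

Without sync_ddp=True the behaviour is unchanged: each process logs its own local value, as before this patch.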