import sys

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import tests.base.develop_utils as tutils
from pytorch_lightning.metrics.converters import (
    _apply_to_inputs,
    _apply_to_outputs,
    convert_to_tensor,
    convert_to_numpy,
    _numpy_metric_conversion,
    _tensor_metric_conversion,
    sync_ddp_if_available,
    gather_all_tensors_if_available,
    tensor_metric,
    numpy_metric,
)


def test_apply_to_inputs():
    def apply_fn(inputs, factor):
        if isinstance(inputs, (float, int)):
            return inputs * factor
        elif isinstance(inputs, dict):
            return {k: apply_fn(v, factor) for k, v in inputs.items()}
        elif isinstance(inputs, (tuple, list)):
            return [apply_fn(x, factor) for x in inputs]

    @_apply_to_inputs(apply_fn, factor=2.)
    def test_fn(*args, **kwargs):
        return args, kwargs

    for args in [[], [1., 2.]]:
        for kwargs in [{}, {'a': 1., 'b': 2.}]:
            result_args, result_kwargs = test_fn(*args, **kwargs)
            assert isinstance(result_args, (list, tuple))
            assert isinstance(result_kwargs, dict)
            assert len(result_args) == len(args)
            assert len(result_kwargs) == len(kwargs)
            assert all(k in result_kwargs for k in kwargs)
            for arg, result_arg in zip(args, result_args):
                assert arg * 2. == result_arg
            for key, arg in kwargs.items():
                assert arg * 2. == result_kwargs[key]


def test_apply_to_outputs():
    def apply_fn(inputs, additional_str):
        return str(inputs) + additional_str

    @_apply_to_outputs(apply_fn, additional_str='_str')
    def test_fn(*args, **kwargs):
        return 'dummy'

    assert test_fn() == 'dummy_str'


def test_convert_to_tensor():
    for test_item in [1., np.array([1.])]:
        result_tensor = convert_to_tensor(test_item)
        assert isinstance(result_tensor, torch.Tensor)
        assert result_tensor.item() == 1.


def test_convert_to_numpy():
    for test_item in [1., torch.tensor([1.])]:
        result = convert_to_numpy(test_item)
        assert isinstance(result, np.ndarray)
        assert result.item() == 1.


def test_numpy_metric_conversion():
    @_numpy_metric_conversion
    def numpy_test_metric(*args, **kwargs):
        for arg in args:
            assert isinstance(arg, np.ndarray)
        for v in kwargs.values():
            assert isinstance(v, np.ndarray)
        return 5.

    # all inputs are converted to numpy arrays, the output back to a tensor
    result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
    assert isinstance(result, torch.Tensor)
    assert result.item() == 5.


def test_tensor_metric_conversion():
    @_tensor_metric_conversion
    def tensor_test_metric(*args, **kwargs):
        for arg in args:
            assert isinstance(arg, torch.Tensor)
        for v in kwargs.values():
            assert isinstance(v, torch.Tensor)
        return 5.

    # all inputs are converted to tensors, as is the output
    result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
    assert isinstance(result, torch.Tensor)
    assert result.item() == 5.
def _setup_ddp(rank, worldsize):
    import os

    os.environ['MASTER_ADDR'] = 'localhost'

    # initialize the process group
    dist.init_process_group('gloo', rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
    _setup_ddp(rank, worldsize)
    if add_offset:
        # each process contributes its own rank, so the reduction is non-trivial
        tensor = torch.tensor([float(rank)])
    else:
        tensor = torch.tensor([1.])

    if reduction_mean:
        reduced_tensor = sync_ddp_if_available(tensor, reduce_op='avg')
        manual_reduction = sum(range(dist.get_world_size())) / dist.get_world_size()
        assert reduced_tensor.item() == manual_reduction
    else:
        # default reduction is a sum, so N processes each holding 1. yield N
        reduced_tensor = sync_ddp_if_available(tensor)
        assert reduced_tensor.item() == dist.get_world_size(), \
            'Sync-Reduce does not work properly with DDP and Tensors'


def _ddp_test_gather_all_tensors(rank, worldsize):
    _setup_ddp(rank, worldsize)
    tensor = torch.tensor([rank])
    gathered_tensors = gather_all_tensors_if_available(tensor)
    manual_tensors = [torch.tensor([i]) for i in range(worldsize)]

    for t1, t2 in zip(gathered_tensors, manual_tensors):
        assert t1.equal(t2)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on Windows")
def test_sync_reduce_ddp():
    """Make sure sync-reduce works with DDP."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize, False), nprocs=worldsize)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on Windows")
def test_sync_reduce_ddp_mean():
    """Make sure mean sync-reduce works with DDP."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize)


def test_sync_reduce_simple():
    """Make sure sync-reduce is a no-op without DDP."""
    tensor = torch.tensor([1.], device='cpu')
    reduced_tensor = sync_ddp_if_available(tensor)

    assert torch.allclose(tensor, reduced_tensor), \
        'Sync-Reduce does not work properly without DDP and Tensors'


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on Windows")
def test_gather_all_tensors_ddp():
    """Make sure gather_all_tensors works with DDP."""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_gather_all_tensors, args=(worldsize,), nprocs=worldsize)


def _test_tensor_metric(is_ddp: bool):
    @tensor_metric()
    def tensor_test_metric(*args, **kwargs):
        for arg in args:
            assert isinstance(arg, torch.Tensor)
        for v in kwargs.values():
            assert isinstance(v, torch.Tensor)
        return 5.

    if is_ddp:
        factor = dist.get_world_size()
    else:
        factor = 1.

    result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
    assert isinstance(result, torch.Tensor)
    # tensor_metric sum-reduces across processes, hence the world-size factor
    assert result.item() == 5. * factor


def _ddp_test_tensor_metric(rank, worldsize):
    _setup_ddp(rank, worldsize)
    _test_tensor_metric(True)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on Windows")
def test_tensor_metric_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()

    world_size = 2
    mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)
    # dist.destroy_process_group()


def test_tensor_metric_simple():
    _test_tensor_metric(False)


def _test_numpy_metric(is_ddp: bool):
    @numpy_metric()
    def numpy_test_metric(*args, **kwargs):
        for arg in args:
            assert isinstance(arg, np.ndarray)
        for v in kwargs.values():
            assert isinstance(v, np.ndarray)
        return 5.

    if is_ddp:
        factor = dist.get_world_size()
    else:
        factor = 1.

    result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
    assert isinstance(result, torch.Tensor)
    # numpy_metric also sum-reduces across processes
    assert result.item() == 5. * factor


def _ddp_test_numpy_metric(rank, worldsize):
    _setup_ddp(rank, worldsize)
    _test_numpy_metric(True)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on Windows")
def test_numpy_metric_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()

    world_size = 2
    mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)
    # dist.destroy_process_group()


def test_numpy_metric_simple():
    _test_numpy_metric(False)
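

# Illustrative sketch, not part of the original suite: how a user-defined metric
# might be wrapped with `tensor_metric()`. Based on the tests above, the decorator
# converts all positional and keyword arguments to tensors before the call, returns
# a tensor, and sum-reduces the result across processes when a DDP process group is
# initialized (a no-op otherwise). The `rmse` metric here is a hypothetical example.
def _example_tensor_metric_usage():
    @tensor_metric()
    def rmse(pred, target):
        # inputs arrive as tensors even if the caller passed numpy arrays
        return torch.sqrt(torch.mean((pred - target) ** 2))

    # numpy inputs are accepted and converted; the output is a tensor
    result = rmse(np.array([1., 2.]), target=np.array([1., 2.]))
    assert isinstance(result, torch.Tensor)
    assert result.item() == 0.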