lightning/tests/tests_pytorch/utilities/test_seed.py

38 lines
1.1 KiB
Python
Raw Normal View History

import random
2021-01-23 23:52:04 +00:00
import numpy as np
import pytest
import torch
from pytorch_lightning.utilities.seed import isolate_rng
from tests_pytorch.helpers.runif import RunIf
@pytest.mark.parametrize("with_torch_cuda", [False, pytest.param(True, marks=RunIf(min_cuda_gpus=1))])
def test_isolate_rng(with_torch_cuda):
"""Test that the isolate_rng context manager isolates the random state from the outer scope."""
# torch
torch.rand(1)
with isolate_rng():
generated = [torch.rand(2) for _ in range(3)]
assert torch.equal(torch.rand(2), generated[0])
# torch.cuda
if with_torch_cuda:
torch.cuda.FloatTensor(1).normal_()
with isolate_rng():
generated = [torch.cuda.FloatTensor(2).normal_() for _ in range(3)]
assert torch.equal(torch.cuda.FloatTensor(2).normal_(), generated[0])
# numpy
np.random.rand(1)
with isolate_rng():
generated = [np.random.rand(2) for _ in range(3)]
assert np.equal(np.random.rand(2), generated[0]).all()
# python
random.random()
with isolate_rng():
generated = [random.random() for _ in range(3)]
assert random.random() == generated[0]