import random

import numpy as np
import pytest
import torch

from pytorch_lightning.utilities.seed import isolate_rng
from tests_pytorch.helpers.runif import RunIf


@pytest.mark.parametrize("with_torch_cuda", [False, pytest.param(True, marks=RunIf(min_cuda_gpus=1))])
def test_isolate_rng(with_torch_cuda):
    """Test that the isolate_rng context manager isolates the random state from the outer scope.

    For each RNG source (torch, torch.cuda, numpy, python's ``random``) the pattern is:

    1. advance the outer RNG stream by one draw,
    2. draw values *inside* ``isolate_rng()`` — this must not disturb the outer state,
    3. assert that the first draw *after* the context equals the first draw made inside it,
       proving the outer stream resumed exactly where it was when the context was entered.

    Args:
        with_torch_cuda: when True (requires at least one CUDA GPU, enforced via ``RunIf``),
            also exercise the CUDA RNG state.
    """
    # torch
    torch.rand(1)  # advance the global CPU RNG so we are not at the seed state
    with isolate_rng():
        generated = [torch.rand(2) for _ in range(3)]
    # outer stream resumed: the next draw repeats the first draw made inside the context
    assert torch.equal(torch.rand(2), generated[0])

    # torch.cuda
    if with_torch_cuda:
        torch.cuda.FloatTensor(1).normal_()  # advance the CUDA RNG
        with isolate_rng():
            generated = [torch.cuda.FloatTensor(2).normal_() for _ in range(3)]
        assert torch.equal(torch.cuda.FloatTensor(2).normal_(), generated[0])

    # numpy
    np.random.rand(1)  # advance numpy's global RNG
    with isolate_rng():
        generated = [np.random.rand(2) for _ in range(3)]
    assert np.equal(np.random.rand(2), generated[0]).all()

    # python
    random.random()  # advance the stdlib RNG
    with isolate_rng():
        generated = [random.random() for _ in range(3)]
    assert random.random() == generated[0]