Merge branch 'feature/seed-sequence' into tests/worker-init-fn
This commit is contained in:
commit
35cf87e22f
|
@ -3,10 +3,11 @@ import os
|
|||
import random
|
||||
from random import getstate as python_get_rng_state
|
||||
from random import setstate as python_set_rng_state
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
|
||||
from lightning.fabric.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn
|
||||
|
||||
|
@ -75,6 +76,9 @@ def reset_seed() -> None:
|
|||
seed_everything(int(seed), workers=bool(int(workers)))
|
||||
|
||||
|
||||
_NUMPY_AVAILABLE = RequirementCache("numpy")
|
||||
|
||||
|
||||
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
|
||||
r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with
|
||||
``seed_everything(seed, workers=True)``.
|
||||
|
@ -91,15 +95,20 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
|
|||
log.debug(
|
||||
f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
|
||||
)
|
||||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
|
||||
# use 128 bits (4 x 32-bit words)
|
||||
np.random.seed(ss.generate_state(4))
|
||||
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
|
||||
torch_ss, stdlib_ss = ss.spawn(2)
|
||||
torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
|
||||
# use 128 bits expressed as an integer
|
||||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
|
||||
random.seed(stdlib_seed)
|
||||
seed_sequence = _generate_seed_sequence(base_seed, worker_id, global_rank, count=4)
|
||||
torch.manual_seed(seed_sequence[0])
|
||||
random.seed((seed_sequence[1] << 32) | seed_sequence[2])
|
||||
if _NUMPY_AVAILABLE:
|
||||
np.random.seed(seed_sequence[3] & 0xFFFFFFFF)
|
||||
|
||||
|
||||
def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]:
|
||||
combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
|
||||
seeds = []
|
||||
for _ in range(count):
|
||||
combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1)
|
||||
seeds.append(combined_seed)
|
||||
return seeds
|
||||
|
||||
|
||||
def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
|
||||
|
|
Loading…
Reference in New Issue