Merge branch 'feature/seed-sequence' into tests/worker-init-fn

This commit is contained in:
awaelchli 2024-07-12 22:00:14 +02:00
commit 35cf87e22f
1 changed file with 19 additions and 10 deletions

View File

@ -3,10 +3,11 @@ import os
import random
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List
import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
from lightning.fabric.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn
@ -75,6 +76,9 @@ def reset_seed() -> None:
seed_everything(int(seed), workers=bool(int(workers)))
# Availability flag for the optional ``numpy`` dependency; ``pl_worker_init_function`` checks it
# before calling ``np.random.seed``.
_NUMPY_AVAILABLE = RequirementCache("numpy")
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with
``seed_everything(seed, workers=True)``.
@ -91,15 +95,20 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
log.debug(
f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
)
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
# use 128 bits (4 x 32-bit words)
np.random.seed(ss.generate_state(4))
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
torch_ss, stdlib_ss = ss.spawn(2)
torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
# use 128 bits expressed as an integer
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
random.seed(stdlib_seed)
seed_sequence = _generate_seed_sequence(base_seed, worker_id, global_rank, count=4)
torch.manual_seed(seed_sequence[0])
random.seed((seed_sequence[1] << 32) | seed_sequence[2])
if _NUMPY_AVAILABLE:
np.random.seed(seed_sequence[3] & 0xFFFFFFFF)
def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]:
combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
seeds = []
for _ in range(count):
combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1)
seeds.append(combined_seed)
return seeds
def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: