diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index b274bce88f..da0d61845d 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -3,10 +3,11 @@ import os import random from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List import numpy as np import torch +from lightning_utilities.core.imports import RequirementCache from lightning.fabric.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn @@ -75,6 +76,9 @@ def reset_seed() -> None: seed_everything(int(seed), workers=bool(int(workers))) +_NUMPY_AVAILABLE = RequirementCache("numpy") + + def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with ``seed_everything(seed, workers=True)``. @@ -91,15 +95,20 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: log.debug( f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}" ) - ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) - # use 128 bits (4 x 32-bit words) - np.random.seed(ss.generate_state(4)) - # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module - torch_ss, stdlib_ss = ss.spawn(2) - torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0]) - # use 128 bits expressed as an integer - stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() - random.seed(stdlib_seed) + seed_sequence = _generate_seed_sequence(base_seed, worker_id, global_rank, count=4) + torch.manual_seed(seed_sequence[0]) + random.seed((seed_sequence[1] << 32) | seed_sequence[2]) + if _NUMPY_AVAILABLE: + np.random.seed(seed_sequence[3] & 0xFFFFFFFF) + + +def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]: + combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank + seeds = [] + for _ in range(count): + combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1) + seeds.append(combined_seed) + return seeds def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: