Seed NumPy using `np.random.SeedSequence()` in `pl_worker_init_function()` to robustly seed NumPy-dependent dataloader workers (#20369)

* Update seed.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update seed.py

* Update seed.py

* Update seed.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
Alex Morehead 2024-11-25 16:40:33 -06:00 committed by GitHub
parent 1f4a77c448
commit 29c0396321
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 4 additions and 1 deletion

View File

@@ -104,7 +104,10 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
if _NUMPY_AVAILABLE:
import numpy as np
- np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only
+ ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
+ np_rng_seed = ss.generate_state(4)
+ np.random.seed(np_rng_seed)
def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]: