From 22d826615f525f45510905a7105f6a2f40b1dc47 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 14 Jun 2021 16:39:50 +0300 Subject: [PATCH] Seed all workers when using DDP (#7942) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Seed all workers when using DDP * Fix to dataloader seeding * Make argument name explicit Co-authored-by: Carlos Mocholí * Use f-strings when logging * Removed a redundant log message Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/seed.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c6d2de89c..76cc6a6486 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -214,6 +214,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942)) + + - Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945)) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 51547d5576..7c20b7d1b3 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -84,8 +84,9 @@ def reset_seed() -> None: If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing. 
""" seed = os.environ.get("PL_GLOBAL_SEED", None) + workers = os.environ.get("PL_SEED_WORKERS", False) if seed is not None: - seed_everything(int(seed)) + seed_everything(int(seed), workers=bool(workers)) def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover @@ -100,6 +101,9 @@ def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # p process_seed = torch.initial_seed() # back out the base seed so we can use all the bits base_seed = process_seed - worker_id + log.debug( + f'Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}' + ) ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) # use 128 bits (4 x 32-bit words) np.random.seed(ss.generate_state(4))