# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to help with reproducibility of models."""

import logging
import os
import random
from typing import Optional

import numpy as np
import torch

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only

log = logging.getLogger(__name__)


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
    """Function that sets the seed for pseudo-random number generators in: pytorch, numpy, python.random.
    In addition, sets the following environment variables:

    - ``PL_GLOBAL_SEED``: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
    - ``PL_SEED_WORKERS``: (optional) is set to 1 if ``workers=True``.

    Args:
        seed: the integer value seed for the global random state in Lightning.
            If ``None``, reads the seed from the ``PL_GLOBAL_SEED`` env variable
            or selects one randomly.
        workers: if set to ``True``, configures all dataloaders passed to the
            Trainer with a ``worker_init_fn``. If the user already provides such a function
            for their dataloaders, setting this argument has no influence. See also:
            :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`.
    """
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min

    if seed is None:
        env_seed = os.environ.get("PL_GLOBAL_SEED")
        if env_seed is None:
            seed = _select_seed_randomly(min_seed_value, max_seed_value)
            rank_zero_warn(f"No seed found, seed set to {seed}")
        else:
            try:
                seed = int(env_seed)
            except ValueError:
                seed = _select_seed_randomly(min_seed_value, max_seed_value)
                rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
    elif not isinstance(seed, int):
        seed = int(seed)

    if not (min_seed_value <= seed <= max_seed_value):
        rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    # using `log.info` instead of `rank_zero_info`,
    # so users can verify the seed is properly set in distributed training.
    log.info(f"Global seed set to {seed}")
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

    return seed
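
# Example usage (a minimal sketch): seed once at the top of a training script, before
# the model and Trainer are constructed, so weight init and data shuffling are
# reproducible across runs. ``model`` below is a placeholder for a LightningModule.
#
#     from pytorch_lightning import Trainer, seed_everything
#
#     seed_everything(42, workers=True)      # also exports PL_GLOBAL_SEED / PL_SEED_WORKERS
#     trainer = Trainer(deterministic=True)  # optionally force deterministic algorithms too
#     trainer.fit(model)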


def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
    # fallback used when no valid seed is available; callers normally pass the full uint32 bounds
    return random.randint(min_seed_value, max_seed_value)


def reset_seed() -> None:
    """Reset the seed to the value that :func:`pytorch_lightning.utilities.seed.seed_everything` previously set.

    If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
    """
    seed = os.environ.get("PL_GLOBAL_SEED", None)
    workers = os.environ.get("PL_SEED_WORKERS", "0")
    if seed is not None:
        seed_everything(int(seed), workers=bool(int(workers)))
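
# Example usage (a sketch): ``reset_seed`` re-applies whatever seed ``seed_everything``
# stored in ``PL_GLOBAL_SEED``, so it can restore the global RNG state after code that
# reseeds it. ``run_code_that_reseeds`` is a hypothetical placeholder.
#
#     seed_everything(123, workers=True)
#     run_code_that_reseeds()  # e.g. a third-party library calling random.seed()
#     reset_seed()             # global RNGs are seeded with 123 again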


def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
    """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed
    with ``seed_everything(seed, workers=True)``.

    See also the PyTorch documentation on
    `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
    """
    # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
    global_rank = rank if rank is not None else rank_zero_only.rank
    process_seed = torch.initial_seed()
    # back out the base seed so we can use all the bits
    base_seed = process_seed - worker_id
    log.debug(
        f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
    )
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # use 128 bits (4 x 32-bit words)
    np.random.seed(ss.generate_state(4))
    # spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
    torch_ss, stdlib_ss = ss.spawn(2)
    torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
    # use 128 bits expressed as an integer
    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
    random.seed(stdlib_seed)
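
# Example usage (a sketch): the Trainer attaches this function automatically when
# ``seed_everything(..., workers=True)`` was called, but it can also be passed to a
# DataLoader by hand. ``dataset`` is a placeholder.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(dataset, num_workers=4, worker_init_fn=pl_worker_init_function)
#     # each worker then derives distinct numpy/torch/stdlib seeds from
#     # (base_seed, worker_id, global_rank), so no two workers share an RNG stream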