61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Helper functions to help with reproducibility of models. """
|
|
|
|
import os
|
|
import random
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from pytorch_lightning import _logger as log
|
|
|
|
|
|
def seed_everything(seed: Optional[int] = None) -> int:
    """Seed the pseudo-random number generators of pytorch, numpy and python.random.

    Also sets the ``PYTHONHASHSEED`` environment variable to the chosen seed.

    Args:
        seed: The seed to apply. Any value accepted by ``int()`` is converted to
            an integer. If ``None``, not convertible by ``int()`` (e.g. ``"abc"``,
            ``float("inf")``), or outside the unsigned 32-bit range numpy accepts,
            a random seed in that range is chosen instead and a warning is logged.

    Returns:
        The integer seed that was actually applied.
    """
    # numpy's global seeding accepts values in the uint32 range [0, 2**32 - 1].
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min

    try:
        if seed is None:
            seed = _select_seed_randomly(min_seed_value, max_seed_value)
        else:
            seed = int(seed)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers e.g. int(float("inf")), which cannot become an int.
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    if not min_seed_value <= seed <= max_seed_value:
        # Implicit string concatenation instead of a backslash continuation inside
        # the f-string, which would leak the next line's indentation into the log.
        log.warning(
            f"{seed} is not in bounds, "
            f"numpy accepts from {min_seed_value} to {max_seed_value}"
        )
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    # The running interpreter's hash seed is fixed at startup; this only takes
    # effect for Python processes started after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed
|
|
|
|
|
|
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
    """Pick a fallback seed uniformly at random and warn about it.

    Args:
        min_seed_value: Smallest seed that may be drawn (inclusive).
        max_seed_value: Largest seed that may be drawn (inclusive).

    Returns:
        The randomly drawn seed.
    """
    # random.randint includes both endpoints.
    fallback_seed = random.randint(min_seed_value, max_seed_value)
    message = f"No correct seed found, seed set to {fallback_seed}"
    log.warning(message)
    return fallback_seed
|