Fix deterministic behavior in ddp_spawn (#3573)
* docs
* set env variable
* fix
* changelog
This commit is contained in:
parent
9acee67c31
commit
a71d62d840
|
@ -68,6 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed dataloader shuffling not getting turned off with `overfit_batches > 0` and `distributed_backend = "ddp"` ([#3534](https://github.com/PyTorchLightning/pytorch-lightning/pull/3534))
|
||||
|
||||
- Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))
|
||||
|
||||
## [0.9.0] - YYYY-MM-DD
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
|
||||
|
@ -22,6 +22,7 @@ import torch.distributed as dist
|
|||
from pytorch_lightning.utilities.cloud_io import atomic_save
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
|
||||
try:
|
||||
from hydra.utils import to_absolute_path, get_original_cwd
|
||||
|
@ -97,6 +98,11 @@ class DDPBase(Accelerator):
|
|||
Returns:
|
||||
|
||||
"""
|
||||
seed = os.environ.get("PL_GLOBAL_SEED")
|
||||
if seed is not None:
|
||||
seed_everything(int(seed))
|
||||
|
||||
|
||||
# offset the process id if requested
|
||||
process_idx = process_idx + proc_offset
|
||||
|
||||
|
|
|
@ -25,17 +25,24 @@ from pytorch_lightning import _logger as log
|
|||
|
||||
|
||||
def seed_everything(seed: Optional[int] = None) -> int:
|
||||
"""Function that sets seed for pseudo-random number generators in:
|
||||
pytorch, numpy, python.random and sets PYTHONHASHSEED environment variable.
|
||||
"""
|
||||
Function that sets seed for pseudo-random number generators in:
|
||||
pytorch, numpy, python.random and sets PYTHONHASHSEED environment variable.
|
||||
In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to
|
||||
spawned subprocesses (e.g. ddp_spawn backend).
|
||||
|
||||
Args:
|
||||
seed: the integer seed value for the global random state in Lightning.
|
||||
If `None`, will read seed from `PL_GLOBAL_SEED` env variable
|
||||
or select it randomly.
|
||||
"""
|
||||
max_seed_value = np.iinfo(np.uint32).max
|
||||
min_seed_value = np.iinfo(np.uint32).min
|
||||
|
||||
try:
|
||||
if seed is None:
|
||||
seed = _select_seed_randomly(min_seed_value, max_seed_value)
|
||||
else:
|
||||
seed = int(seed)
|
||||
seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value))
|
||||
seed = int(seed)
|
||||
except (TypeError, ValueError):
|
||||
seed = _select_seed_randomly(min_seed_value, max_seed_value)
|
||||
|
||||
|
@ -47,6 +54,7 @@ def seed_everything(seed: Optional[int] = None) -> int:
|
|||
seed = _select_seed_randomly(min_seed_value, max_seed_value)
|
||||
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
|
Loading…
Reference in New Issue