Support deterministic="warn" in Trainer for PyTorch 1.11+ (#12588)
Co-authored-by: carmocca <carlossmocholi@gmail.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
parent a41486245a
commit 6490996b39
@@ -13,17 +13,10 @@ concurrency:
 jobs:
   mypy:
     runs-on: ubuntu-20.04
-    #strategy:
-    #  fail-fast: false
-    #  matrix:
-    #    include:
-    #      - {python-version: "3.8", pytorch-version: "1.8"}
-    #      - {python-version: "3.9", pytorch-version: "1.10"}
     steps:
     - uses: actions/checkout@master
     - uses: actions/setup-python@v2
       with:
-        # python-version: ${{ matrix.python-version }}
         python-version: 3.9

     # Note: This uses an internal pip API and may not always work
@ -37,15 +30,10 @@ jobs:
|
|||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
env:
|
||||
# TORCH_VERSION: ${{ matrix.pytorch-version }}
|
||||
TORCH_VERSION: "1.10"
|
||||
run: |
|
||||
pip install "torch==$TORCH_VERSION" --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
# adjust versions according installed Torch version
|
||||
pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
python ./requirements/adjust-versions.py requirements/extra.txt
|
||||
python ./requirements/adjust-versions.py requirements/examples.txt
|
||||
pip install '.[dev]' --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
pip install '.[dev]'
|
||||
pip list
|
||||
|
||||
- name: Type check
|
||||
|
|
|
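The CI pin to torch==1.11 matters because the warn_only keyword used further down only exists from PyTorch 1.11 onward, so mypy must type-check against that version. As a minimal sketch, a runtime gate like the _TORCH_GREATER_EQUAL_1_11 flag imported below can be derived from the installed version; the real flag lives in pytorch_lightning.utilities.imports, and this standalone approximation is an assumption, not the library's exact code:

    # Sketch of a version gate comparable to _TORCH_GREATER_EQUAL_1_11 (assumption).
    import torch
    from packaging.version import Version

    # strip any local build suffix such as "+cpu" before comparing
    _TORCH_GREATER_EQUAL_1_11 = Version(torch.__version__.split("+")[0]) >= Version("1.11.0")

    if _TORCH_GREATER_EQUAL_1_11:
        # warn_only was introduced in PyTorch 1.11
        torch.use_deterministic_algorithms(True, warn_only=True)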
@@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))

--
+- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))

 -
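To illustrate the changelog entry, a minimal usage sketch of the new value:

    from pytorch_lightning import Trainer

    # deterministic accepts True, False, and (new in this PR) "warn":
    #   True   -> ops without a deterministic implementation raise an error
    #   "warn" -> such ops emit a warning instead (PyTorch 1.11+)
    trainer = Trainer(deterministic="warn")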
@@ -18,6 +18,7 @@ from collections import Counter
 from typing import Dict, List, Optional, Union

 import torch
+from typing_extensions import Literal

 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator

@@ -80,13 +81,21 @@ from pytorch_lightning.utilities import (
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities.imports import (
+    _HOROVOD_AVAILABLE,
+    _HPU_AVAILABLE,
+    _IPU_AVAILABLE,
+    _TORCH_GREATER_EQUAL_1_11,
+    _TPU_AVAILABLE,
+)

 log = logging.getLogger(__name__)

 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd

+_LITERAL_WARN = Literal["warn"]
+

 class AcceleratorConnector:
     def __init__(
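The _LITERAL_WARN alias lets static checkers reject any string other than "warn" while still accepting booleans. A minimal sketch of what the alias buys, assuming a mypy-style checker (the function name here is hypothetical):

    from typing import Union
    from typing_extensions import Literal

    _LITERAL_WARN = Literal["warn"]

    def set_deterministic(deterministic: Union[bool, _LITERAL_WARN]) -> None:
        ...

    set_deterministic(True)      # ok
    set_deterministic("warn")    # ok: the only string the Literal admits
    set_deterministic("strict")  # runs at runtime, but mypy rejects it:
                                 # not bool and not Literal["warn"]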
@ -102,7 +111,7 @@ class AcceleratorConnector:
|
|||
sync_batchnorm: bool = False,
|
||||
benchmark: Optional[bool] = None,
|
||||
replace_sampler_ddp: bool = True,
|
||||
deterministic: bool = False,
|
||||
deterministic: Union[bool, _LITERAL_WARN] = False,
|
||||
auto_select_gpus: bool = False,
|
||||
num_processes: Optional[int] = None, # deprecated
|
||||
tpu_cores: Optional[Union[List[int], str, int]] = None, # deprecated
|
||||
|
@@ -205,9 +214,12 @@ class AcceleratorConnector:
         # 6. Instantiate Strategy - Part 2
         self._lazy_init_strategy()

-    def _init_deterministic(self, deterministic: bool) -> None:
+    def _init_deterministic(self, deterministic: Union[bool, _LITERAL_WARN]) -> None:
         self.deterministic = deterministic
-        torch.use_deterministic_algorithms(deterministic)
+        if _TORCH_GREATER_EQUAL_1_11 and deterministic == "warn":
+            torch.use_deterministic_algorithms(True, warn_only=True)
+        else:
+            torch.use_deterministic_algorithms(deterministic)
         if deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
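The branch above maps "warn" onto PyTorch's warn-only mode. A minimal standalone sketch of the torch API involved (requires PyTorch 1.11+):

    import torch

    # "warn" path: request deterministic algorithms, but only warn when an op
    # has no deterministic implementation (instead of raising RuntimeError).
    torch.use_deterministic_algorithms(True, warn_only=True)
    assert torch.are_deterministic_algorithms_enabled()
    assert torch.is_deterministic_algorithms_warn_only_enabled()

    # True path: the same call without warn_only makes such ops raise.
    torch.use_deterministic_algorithms(True)
    assert not torch.is_deterministic_algorithms_warn_only_enabled()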
@@ -65,7 +65,7 @@ from pytorch_lightning.strategies import ParallelStrategy, Strategy
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
@@ -173,7 +173,7 @@ class Trainer(
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
         profiler: Optional[Union[BaseProfiler, str]] = None,
         benchmark: Optional[bool] = None,
-        deterministic: bool = False,
+        deterministic: Union[bool, _LITERAL_WARN] = False,
         reload_dataloaders_every_n_epochs: int = 0,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
@@ -257,6 +257,8 @@ class Trainer(
             Default: ``False``.

         deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+            Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+            that don't support deterministic mode (requires PyTorch 1.11+).
             Default: ``False``.

         devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
@@ -699,7 +699,7 @@ def test_parallel_devices_in_strategy_confilict_with_accelerator(parallel_device
     Trainer(strategy=DDPStrategy(parallel_devices=parallel_devices), accelerator=accelerator)


-@pytest.mark.parametrize("deterministic", [True, False])
+@pytest.mark.parametrize("deterministic", [True, False, pytest.param("warn", marks=RunIf(min_torch="1.11.0"))])
 def test_deterministic_init(deterministic):
     trainer = Trainer(accelerator="auto", deterministic=deterministic)
     assert trainer._accelerator_connector.deterministic == deterministic
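RunIf(min_torch="1.11.0") is Lightning's in-repo test helper for conditionally skipping parameters. Outside the repo, the same conditional parametrization can be sketched with a plain pytest.mark.skipif; the mark name and test body below are stand-ins, not the repo's code:

    # Sketch: conditional parametrization without Lightning's RunIf helper.
    import pytest
    import torch
    from packaging.version import Version

    _min_torch_1_11 = pytest.mark.skipif(
        Version(torch.__version__.split("+")[0]) < Version("1.11.0"),
        reason="deterministic='warn' requires torch>=1.11",
    )

    @pytest.mark.parametrize("deterministic", [True, False, pytest.param("warn", marks=_min_torch_1_11)])
    def test_deterministic_flag(deterministic):
        ...  # construct the Trainer and assert on the stored flag, as above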