Support deterministic="warn" in Trainer for PyTorch 1.11+ (#12588)

Co-authored-by: carmocca <carlossmocholi@gmail.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Wei Ji 2022-04-27 08:05:26 -04:00 committed by GitHub
parent a41486245a
commit 6490996b39
5 changed files with 24 additions and 22 deletions

@@ -13,17 +13,10 @@ concurrency:
 jobs:
   mypy:
     runs-on: ubuntu-20.04
-    #strategy:
-    #  fail-fast: false
-    #  matrix:
-    #    include:
-    #      - {python-version: "3.8", pytorch-version: "1.8"}
-    #      - {python-version: "3.9", pytorch-version: "1.10"}
     steps:
       - uses: actions/checkout@master
       - uses: actions/setup-python@v2
         with:
-          # python-version: ${{ matrix.python-version }}
           python-version: 3.9
     # Note: This uses an internal pip API and may not always work
@@ -37,15 +30,10 @@ jobs:
             ${{ runner.os }}-pip-
       - name: Install dependencies
-        env:
-          # TORCH_VERSION: ${{ matrix.pytorch-version }}
-          TORCH_VERSION: "1.10"
         run: |
-          pip install "torch==$TORCH_VERSION" --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+          # adjust versions according installed Torch version
+          pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
           python ./requirements/adjust-versions.py requirements/extra.txt
           python ./requirements/adjust-versions.py requirements/examples.txt
-          pip install '.[dev]' --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install '.[dev]'
           pip list
       - name: Type check

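The CI change above drops the commented-out version matrix and pins torch to 1.11, the first release that supports warn-only determinism. A quick sanity-check sketch (not part of the PR; it assumes the `packaging` package is available, as it is in most torch environments) confirms a local install meets the same floor:

```python
# Sanity-check sketch (not part of the PR): confirm the local torch build
# meets the 1.11 floor that both the CI pin and deterministic="warn" assume.
from packaging.version import Version

import torch

installed = Version(torch.__version__.split("+")[0])  # strip local tags like "+cpu"
assert installed >= Version("1.11.0"), f"torch {installed} < 1.11, warn-only mode unavailable"
```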

@@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))
--
+- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))
 -

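In user code, the new value slots into the existing `deterministic` argument; a minimal sketch of the documented usage (model and data omitted):

```python
from pytorch_lightning import Trainer

# "warn" enables deterministic algorithms wherever possible and only warns,
# rather than raising, when an op has no deterministic implementation.
# Requires PyTorch 1.11+; True and False keep their previous behavior.
trainer = Trainer(deterministic="warn")
```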

@@ -18,6 +18,7 @@ from collections import Counter
 from typing import Dict, List, Optional, Union
 
 import torch
+from typing_extensions import Literal
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -80,13 +81,21 @@ from pytorch_lightning.utilities import (
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities.imports import (
+    _HOROVOD_AVAILABLE,
+    _HPU_AVAILABLE,
+    _IPU_AVAILABLE,
+    _TORCH_GREATER_EQUAL_1_11,
+    _TPU_AVAILABLE,
+)
 
 log = logging.getLogger(__name__)
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
 
+_LITERAL_WARN = Literal["warn"]
+
 
 class AcceleratorConnector:
     def __init__(
@@ -102,7 +111,7 @@ class AcceleratorConnector:
         sync_batchnorm: bool = False,
         benchmark: Optional[bool] = None,
         replace_sampler_ddp: bool = True,
-        deterministic: bool = False,
+        deterministic: Union[bool, _LITERAL_WARN] = False,
         auto_select_gpus: bool = False,
         num_processes: Optional[int] = None,  # deprecated
         tpu_cores: Optional[Union[List[int], str, int]] = None,  # deprecated
@@ -205,9 +214,12 @@ class AcceleratorConnector:
         # 6. Instantiate Strategy - Part 2
         self._lazy_init_strategy()
 
-    def _init_deterministic(self, deterministic: bool) -> None:
+    def _init_deterministic(self, deterministic: Union[bool, _LITERAL_WARN]) -> None:
         self.deterministic = deterministic
-        torch.use_deterministic_algorithms(deterministic)
+        if _TORCH_GREATER_EQUAL_1_11 and deterministic == "warn":
+            torch.use_deterministic_algorithms(True, warn_only=True)
+        else:
+            torch.use_deterministic_algorithms(deterministic)
         if deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383

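Under the hood this leans on `torch.use_deterministic_algorithms(mode, warn_only=...)`, whose `warn_only` flag first appeared in PyTorch 1.11. A standalone sketch of the same gating, using `packaging` for the version test in place of the private `_TORCH_GREATER_EQUAL_1_11` constant (the function name here is illustrative, not the connector's):

```python
from typing import Union

import torch
from packaging.version import Version
from typing_extensions import Literal

_TORCH_GE_1_11 = Version(torch.__version__.split("+")[0]) >= Version("1.11.0")

def init_deterministic(deterministic: Union[bool, Literal["warn"]]) -> None:
    """Mirror of the connector logic: warn-only on 1.11+, strict otherwise."""
    if _TORCH_GE_1_11 and deterministic == "warn":
        # Nondeterministic ops emit a UserWarning instead of a RuntimeError.
        torch.use_deterministic_algorithms(True, warn_only=True)
    else:
        # Coerced to bool here so an older torch never sees the string;
        # the PR itself passes the value straight through.
        torch.use_deterministic_algorithms(bool(deterministic))

init_deterministic("warn")
```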

@@ -65,7 +65,7 @@ from pytorch_lightning.strategies import ParallelStrategy, Strategy
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
@@ -173,7 +173,7 @@ class Trainer(
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
         profiler: Optional[Union[BaseProfiler, str]] = None,
         benchmark: Optional[bool] = None,
-        deterministic: bool = False,
+        deterministic: Union[bool, _LITERAL_WARN] = False,
         reload_dataloaders_every_n_epochs: int = 0,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
@@ -257,6 +257,8 @@ class Trainer(
                 Default: ``False``.
 
             deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+                Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+                that don't support deterministic mode (requires PyTorch 1.11+).
                 Default: ``False``.
 
             devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,

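The `Union[bool, _LITERAL_WARN]` annotation means a type checker accepts exactly `True`, `False`, or the literal string `"warn"` for this argument. A self-contained sketch of the pattern (the alias and function names here are illustrative, not from the PR):

```python
from typing import Union

from typing_extensions import Literal

DeterministicMode = Union[bool, Literal["warn"]]  # illustrative alias

def describe(deterministic: DeterministicMode) -> str:
    if deterministic == "warn":
        return "deterministic where possible, warn otherwise"
    return "strictly deterministic" if deterministic else "non-deterministic"

describe(True)
describe("warn")
# describe("strict")  # rejected by mypy: not a valid literal for DeterministicMode
```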

@@ -699,7 +699,7 @@ def test_parallel_devices_in_strategy_confilict_with_accelerator(parallel_device
     Trainer(strategy=DDPStrategy(parallel_devices=parallel_devices), accelerator=accelerator)
 
 
-@pytest.mark.parametrize("deterministic", [True, False])
+@pytest.mark.parametrize("deterministic", [True, False, pytest.param("warn", marks=RunIf(min_torch="1.11.0"))])
 def test_deterministic_init(deterministic):
     trainer = Trainer(accelerator="auto", deterministic=deterministic)
     assert trainer._accelerator_connector.deterministic == deterministic
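`RunIf(min_torch="1.11.0")` is the test suite's own helper: on older torch versions the `"warn"` case is skipped rather than failed. A rough plain-pytest equivalent of that gate, assuming `RunIf` boils down to a `skipif` mark (the test body below is a hypothetical minimal check, not the repo's test):

```python
import pytest
import torch
from packaging.version import Version

# Rough stand-in for RunIf(min_torch="1.11.0"): skip when torch is too old.
requires_torch_1_11 = pytest.mark.skipif(
    Version(torch.__version__.split("+")[0]) < Version("1.11.0"),
    reason="deterministic='warn' needs torch>=1.11",
)

@pytest.mark.parametrize(
    "deterministic",
    [True, False, pytest.param("warn", marks=requires_torch_1_11)],
)
def test_deterministic_flag_accepted(deterministic):
    assert deterministic in (True, False, "warn")
```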