From c3614f1c07caf2c1519fcbe37140f2aeef7a8308 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Fri, 22 Oct 2021 02:31:56 +0530
Subject: [PATCH] Fix: skip importing DistributedOptimizer for Windows
 (#10071)

---
 docs/source/advanced/mixed_precision.rst       | 2 +-
 pytorch_lightning/core/lightning.py            | 3 ++-
 pytorch_lightning/plugins/training_type/ddp.py | 8 ++++++--
 pytorch_lightning/utilities/__init__.py        | 1 +
 tests/core/test_lightning_module.py            | 3 ++-
 5 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/docs/source/advanced/mixed_precision.rst b/docs/source/advanced/mixed_precision.rst
index 0954705240..1c98f663ed 100644
--- a/docs/source/advanced/mixed_precision.rst
+++ b/docs/source/advanced/mixed_precision.rst
@@ -50,7 +50,7 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.
 
 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
+    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
 
     Trainer(gpus=1, precision="bf16")
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 6f03fd9369..4546cc0c80 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -38,6 +38,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import (
+    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_DEV_1_10,
     GradClipAlgorithmType,
     rank_zero_deprecation,
@@ -2041,7 +2042,7 @@ class LightningModule(
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_DEV_1_10:
+        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
             return
 
         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
 
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 64fc1a5a97..4499c1d7df 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -42,6 +42,7 @@ from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import (
     _FAIRSCALE_AVAILABLE,
     _HYDRA_AVAILABLE,
+    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
@@ -57,7 +58,9 @@ from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_10:
-    from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+    if not _IS_WINDOWS:
+        from torch.distributed.optim import DistributedOptimizer
+    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -333,8 +336,9 @@ class DDPPlugin(ParallelPlugin):
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
 
+        is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
         if (
-            isinstance(optimizer, DistributedOptimizer)
+            is_distributed_optimizer
             or isinstance(optimizer, ZeroRedundancyOptimizer)
             or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
         ):
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 94d9159c9f..bc19aa1366 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -38,6 +38,7 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _HYDRA_EXPERIMENTAL_AVAILABLE,
     _IPU_AVAILABLE,
     _IS_INTERACTIVE,
+    _IS_WINDOWS,
     _JSONARGPARSE_AVAILABLE,
     _module_available,
     _OMEGACONF_AVAILABLE,
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 18260339d8..135e437c4b 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -21,7 +21,7 @@ from torch.optim import Adam, SGD
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
+from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -315,6 +315,7 @@ class BoringModelWithShardedTensor(BoringModel):
 @pytest.mark.skipif(
     not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
 )
+@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
 def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     spec = dist._sharding_spec.ChunkShardingSpec(
         dim=0,