Fix: skip importing DistributedOptimizer for Windows (#10071)

This commit is contained in:
Kaushik B 2021-10-22 02:31:56 +05:30 committed by GitHub
parent 454e93bace
commit c3614f1c07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 12 additions and 5 deletions

View File

@ -50,7 +50,7 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.
.. testcode::
:skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
:skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
Trainer(gpus=1, precision="bf16")

View File

@ -38,6 +38,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import (
_IS_WINDOWS,
_TORCH_GREATER_EQUAL_DEV_1_10,
GradClipAlgorithmType,
rank_zero_deprecation,
@ -2041,7 +2042,7 @@ class LightningModule(
These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
"""
if not _TORCH_GREATER_EQUAL_DEV_1_10:
if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
return
from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

View File

@ -42,6 +42,7 @@ from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
_FAIRSCALE_AVAILABLE,
_HYDRA_AVAILABLE,
_IS_WINDOWS,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
@ -57,7 +58,9 @@ from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
if not _IS_WINDOWS:
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
@ -333,8 +336,9 @@ class DDPPlugin(ParallelPlugin):
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
if (
isinstance(optimizer, DistributedOptimizer)
is_distributed_optimizer
or isinstance(optimizer, ZeroRedundancyOptimizer)
or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
):

View File

@ -38,6 +38,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
_HYDRA_EXPERIMENTAL_AVAILABLE,
_IPU_AVAILABLE,
_IS_INTERACTIVE,
_IS_WINDOWS,
_JSONARGPARSE_AVAILABLE,
_module_available,
_OMEGACONF_AVAILABLE,

View File

@ -21,7 +21,7 @@ from torch.optim import Adam, SGD
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@ -315,6 +315,7 @@ class BoringModelWithShardedTensor(BoringModel):
@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
)
@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
spec = dist._sharding_spec.ChunkShardingSpec(
dim=0,