Only import PostLocalSGD related modules when it's needed (#10359)

* Only import PostLocalSGD related modules when it's needed
This commit is contained in:
four4fish 2021-11-07 18:05:44 -08:00 committed by GitHub
parent 45f6a3b175
commit bdc24e558a
1 changed file with 12 additions and 14 deletions

@@ -63,11 +63,6 @@ from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, Mi
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
-if _TORCH_GREATER_EQUAL_1_10:
-    if not _IS_WINDOWS:
-        from torch.distributed.optim import DistributedOptimizer
-    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
 if _HYDRA_AVAILABLE:
@@ -75,9 +70,7 @@ if _HYDRA_AVAILABLE:
     from hydra.utils import get_original_cwd, to_absolute_path
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
-if _TORCH_GREATER_EQUAL_1_10:
-    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
-    import torch.distributed.algorithms.model_averaging.averagers as averagers
 log = logging.getLogger(__name__)
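
The two hunks above drop the module-level imports: the post-localSGD hook and model-averaging modules only ship with torch >= 1.10, and torch.distributed.optim.DistributedOptimizer is unavailable on Windows, so importing them at the top of ddp.py could fail even when post-localSGD is never used. The pattern being applied is a deferred, call-site import behind the existing guards; a minimal sketch, not the Lightning code, with assumed stand-ins for the _TORCH_GREATER_EQUAL_1_10 and _IS_WINDOWS flags:

import platform

import torch

# Assumed stand-ins for the version/platform guards referenced in the diff.
_TORCH_GREATER_EQUAL_1_10 = tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (1, 10)
_IS_WINDOWS = platform.system() == "Windows"


def _post_localsgd_modules():
    """Import the optional modules only on the code path that needs them."""
    if not _TORCH_GREATER_EQUAL_1_10:
        raise RuntimeError("post-localSGD support requires torch>=1.10")
    # Deferred imports: merely importing this file never touches the optional modules.
    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
    import torch.distributed.algorithms.model_averaging.averagers as averagers

    return post_localSGD, averagers
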
@@ -312,12 +305,11 @@ class DDPPlugin(ParallelPlugin):
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
             )
-        if (
-            _TORCH_GREATER_EQUAL_1_10
-            and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
-            and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-        ):
-            self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+        if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+
+            if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
+                self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
 
     def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         optimizers = self.lightning_module.trainer.optimizers
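
With this change the comm-hook module is imported inside the method, and the PostLocalSGDState check only runs while fitting on torch >= 1.10. For context, this is roughly how that state object is paired with the hook on a plain DDP model outside Lightning; a sketch that assumes an already initialized process group and a single-device module:

import torch
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
from torch.nn.parallel import DistributedDataParallel

# Assumes torch.distributed.init_process_group(...) has already been called.
model = DistributedDataParallel(torch.nn.Linear(8, 8))
state = post_localSGD.PostLocalSGDState(
    process_group=None,       # use the default process group
    subgroup=None,            # average across the whole group
    start_localSGD_iter=100,  # keep plain allreduce for the first 100 steps
)
# After start_localSGD_iter steps the hook switches from global allreduce to local SGD,
# which is why the plugin also re-initializes the optimizers with a model averager.
model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
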
@@ -325,6 +317,12 @@ class DDPPlugin(ParallelPlugin):
             raise ValueError(
                 "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
             )
+        if _TORCH_GREATER_EQUAL_1_10:
+            if not _IS_WINDOWS:
+                from torch.distributed.optim import DistributedOptimizer
+            import torch.distributed.algorithms.model_averaging.averagers as averagers
+            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
         averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
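
The truncated hunk goes on to build the averager from the freshly imported module and to wrap each optimizer for post-localSGD. The standalone, documented usage of PeriodicModelAverager looks roughly like the sketch below (torch >= 1.10 and an initialized process group assumed; the model, data, and hyperparameters are placeholders):

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers

# Assumes torch.distributed.init_process_group(...) has already been called.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Run ordinary optimizer steps during warmup, then average parameters across
# ranks every `period` steps.
averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
for step in range(200):
    optimizer.zero_grad()
    loss = model(torch.randn(2, 8)).sum()
    loss.backward()
    optimizer.step()
    averager.average_parameters(model.parameters())
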