Only import PostLocalSGD related modules when it's needed (#10359)
* Only import PostLocalSGD related modules when it's needed
This commit is contained in:
parent
45f6a3b175
commit
bdc24e558a
@@ -63,11 +63,6 @@ from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, Mi
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
-if _TORCH_GREATER_EQUAL_1_10:
-    if not _IS_WINDOWS:
-        from torch.distributed.optim import DistributedOptimizer
-    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
-
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
 if _HYDRA_AVAILABLE:
@@ -75,9 +70,7 @@ if _HYDRA_AVAILABLE:
     from hydra.utils import get_original_cwd, to_absolute_path
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
-if _TORCH_GREATER_EQUAL_1_10:
-    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
-    import torch.distributed.algorithms.model_averaging.averagers as averagers
 
 
 log = logging.getLogger(__name__)
@@ -312,12 +305,11 @@ class DDPPlugin(ParallelPlugin):
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
             )
 
-        if (
-            _TORCH_GREATER_EQUAL_1_10
-            and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
-            and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-        ):
-            self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+        if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+
+            if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
+                self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
 
     def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         optimizers = self.lightning_module.trainer.optimizers
@@ -325,6 +317,12 @@ class DDPPlugin(ParallelPlugin):
             raise ValueError(
                 "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
             )
+        if _TORCH_GREATER_EQUAL_1_10:
+            if not _IS_WINDOWS:
+                from torch.distributed.optim import DistributedOptimizer
+            import torch.distributed.algorithms.model_averaging.averagers as averagers
+            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
         averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
         for x, optimizer in enumerate(optimizers):
            if isinstance(optimizer, LightningOptimizer):