From 908b9eebc70b846f0d746769562799ec56c5f235 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 30 Aug 2021 16:56:16 +0200
Subject: [PATCH] move `block_ddp_sync_behaviour` to utilities (#9192)

---
 CHANGELOG.md                                  |  4 +++
 pytorch_lightning/core/optimizer.py           |  5 ++-
 .../loops/batch/training_batch_loop.py        | 32 +++----------------
 pytorch_lightning/loops/utilities.py          | 25 +++++++++++++--
 4 files changed, 35 insertions(+), 31 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 588dd5772a..e1fe424b25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -145,6 +145,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed `rank_zero_warn` to `NotImplementedError` in the `{train, val, test, predict}_dataloader` hooks that `Lightning(Data)Module` uses ([#9161](https://github.com/PyTorchLightning/pytorch-lightning/pull/9161))
 
+
+- Moved `block_ddp_sync_behaviour` out of `TrainingBatchLoop` to loop utilities ([#9192](https://github.com/PyTorchLightning/pytorch-lightning/pull/9192))
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index d8d51e76dc..362c160559 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -116,7 +116,10 @@ class LightningOptimizer:
             during the accumulation phase.
             Setting `sync_grad` to False will block this synchronization and improve performance.
         """
-        with self._trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
+        # local import here to avoid circular import
+        from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
+
+        with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
             self._toggle_model()
             yield
             self._untoggle_model()
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 3f94a01816..8da61141de 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 from collections import OrderedDict
-from contextlib import contextmanager
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -28,11 +27,11 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.closure import Closure, ClosureResult
 from pytorch_lightning.loops.utilities import (
+    _block_parallel_sync_behavior,
     _check_training_step_output,
     _process_training_step_output,
     check_finite_loss,
 )
-from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.progress import OptimizationProgress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
@@ -186,9 +185,8 @@ class TrainingBatchLoop(Loop):
             # -------------------
             # calculate loss (train step + train step end)
             # -------------------
-            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
-            # automatic_optimization=False: don't block synchronization here
-            with self.block_ddp_sync_behaviour():
+            # automatic_optimization: perform ddp sync only when performing optimizer_step
+            with _block_parallel_sync_behavior(self._trainer):
                 closure()
 
             # ------------------------------
@@ -460,28 +458,6 @@ class TrainingBatchLoop(Loop):
         model = self.trainer.lightning_module
         model.untoggle_optimizer(opt_idx)
 
-    @contextmanager
-    def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]:
-        """
-        automatic_optimization = True
-        Blocks ddp sync gradients behaviour on backwards pass.
-        This is useful for skipping sync when accumulating gradients, reducing communication overhead
-
-        automatic_optimization = False
-        do not block ddp gradient sync when using manual optimization
-        as gradients are needed within the training step
-
-        Returns:
-            context manager with sync behaviour off
-        """
-        if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (
-            self.trainer.lightning_module.automatic_optimization or should_block_sync
-        ):
-            with self.trainer.training_type_plugin.block_backward_sync():
-                yield None
-        else:
-            yield None
-
     def backward(
         self,
         loss: Tensor,
diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py
index dd69640106..c4f4d201cb 100644
--- a/pytorch_lightning/loops/utilities.py
+++ b/pytorch_lightning/loops/utilities.py
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import Any, Iterator, Mapping, Optional, Tuple
+from contextlib import contextmanager
+from typing import Any, Generator, Iterator, Mapping, Optional, Tuple
 
 import torch
 
 import pytorch_lightning as pl
+from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -113,3 +114,23 @@ def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int)
     else:
         dataloader_iter = iter(data_fetcher)
     return dataloader_iter
+
+
+@contextmanager
+def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
+    """
+    Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`.
+    This is useful, for example, when accumulating gradients to reduce communication when it is not needed.
+
+    Args:
+        trainer: the trainer instance with a reference to a training type plugin
+        block: whether the context manager is enabled or not
+
+    Returns:
+        context manager with sync behaviour off
+    """
+    if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
+        with trainer.training_type_plugin.block_backward_sync():
+            yield None
+    else:
+        yield None
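
Note (not part of the patch above): the moved helper is a thin guard around the training type plugin's `block_backward_sync()`, which for DDP-based plugins typically maps onto PyTorch's `DistributedDataParallel.no_sync()`. The sketch below is an illustrative, dependency-light approximation of the same pattern using plain PyTorch DDP rather than a Lightning `ParallelPlugin`; the function name and the accumulate-every-4-batches loop are hypothetical and only meant to show how such a context manager is used while accumulating gradients.

```python
# Illustrative sketch only -- not the Lightning implementation. DDP's `no_sync()`
# plays the role of `ParallelPlugin.block_backward_sync()` here.
from contextlib import contextmanager
from typing import Generator

import torch
from torch.nn.parallel import DistributedDataParallel


@contextmanager
def block_parallel_sync_behavior(module: torch.nn.Module, block: bool = True) -> Generator[None, None, None]:
    """Skip the gradient all-reduce for backward passes run inside this context.

    Acts as a no-op when `module` is not DDP-wrapped or when `block` is False.
    """
    if isinstance(module, DistributedDataParallel) and block:
        # `no_sync()` disables gradient synchronization, which is what you want
        # while accumulating gradients between optimizer steps.
        with module.no_sync():
            yield
    else:
        yield


# Hypothetical usage with gradient accumulation over 4 batches
# (assumes an initialized process group, a DDP-wrapped `model`,
# an `optimizer`, a `loss_fn`, and a `dataloader`):
#
# for i, batch in enumerate(dataloader):
#     sync_now = (i + 1) % 4 == 0
#     with block_parallel_sync_behavior(model, block=not sync_now):
#         loss_fn(model(batch)).backward()
#     if sync_now:
#         optimizer.step()
#         optimizer.zero_grad()
```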