move `block_ddp_sync_behaviour` to utilities (#9192)

Adrian Wälchli 2021-08-30 16:56:16 +02:00 committed by GitHub
parent f79993a705
commit 908b9eebc7
4 changed files with 35 additions and 31 deletions

CHANGELOG.md

@@ -145,6 +145,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `rank_zero_warn` to `NotImplementedError` in the `{train, val, test, predict}_dataloader` hooks that `Lightning(Data)Module` uses ([#9161](https://github.com/PyTorchLightning/pytorch-lightning/pull/9161))
+- Moved `block_ddp_sync_behaviour` out of `TrainingBatchLoop` to loop utilities ([#9192](https://github.com/PyTorchLightning/pytorch-lightning/pull/9192))

 ### Deprecated

 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`

pytorch_lightning/core/optimizer.py

@@ -116,7 +116,10 @@ class LightningOptimizer:
         during the accumulation phase.
         Setting `sync_grad` to False will block this synchronization and improve performance.
         """
-        with self._trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
+        # local import here to avoid circular import
+        from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
+
+        with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
             self._toggle_model()
             yield
             self._untoggle_model()
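For context, the user-facing path into this code is `LightningOptimizer.toggle_model(sync_grad=...)` under manual optimization. Below is a hedged sketch (the module, layer sizes, and accumulation factor are invented for illustration, not part of this commit) of how a `LightningModule` might pass `sync_grad=False` on accumulation steps so that DDP only synchronizes gradients on the batch that actually steps the optimizer:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class ManualAccumulationModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # step every 4 batches; skip gradient sync on the 3 accumulation batches
        sync = (batch_idx + 1) % 4 == 0
        with opt.toggle_model(sync_grad=sync):
            x, y = batch
            loss = F.cross_entropy(self.layer(x), y)
            self.manual_backward(loss)
        if sync:
            opt.step()
            opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```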

pytorch_lightning/loops/batch/training_batch_loop.py

@@ -13,10 +13,9 @@
 # limitations under the License.

 from collections import OrderedDict
-from contextlib import contextmanager
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy as np
 import torch
@@ -28,11 +27,11 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.closure import Closure, ClosureResult
 from pytorch_lightning.loops.utilities import (
+    _block_parallel_sync_behavior,
     _check_training_step_output,
     _process_training_step_output,
     check_finite_loss,
 )
-from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.progress import OptimizationProgress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
@@ -186,9 +185,8 @@ class TrainingBatchLoop(Loop):
         # -------------------
         # calculate loss (train step + train step end)
         # -------------------
-        # automatic_optimization=True: perform ddp sync only when performing optimizer_step
-        # automatic_optimization=False: don't block synchronization here
-        with self.block_ddp_sync_behaviour():
+        # automatic_optimization: perform ddp sync only when performing optimizer_step
+        with _block_parallel_sync_behavior(self._trainer):
             closure()

         # ------------------------------
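For reference, `block_backward_sync()` on DDP-style plugins typically wraps PyTorch's `DistributedDataParallel.no_sync()`. A minimal raw-PyTorch sketch of the same idea (not Lightning code; assumes the process group and `ddp_model` are already set up elsewhere):

```python
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def accumulate_then_sync(ddp_model: DDP, optimizer, accumulation_batches, final_batch):
    # backward passes inside no_sync() skip the gradient all-reduce;
    # gradients simply accumulate locally on each rank
    with ddp_model.no_sync():
        for x, y in accumulation_batches:
            F.mse_loss(ddp_model(x), y).backward()
    # the backward outside no_sync() triggers the all-reduce of the accumulated grads
    x, y = final_batch
    F.mse_loss(ddp_model(x), y).backward()
    optimizer.step()
    optimizer.zero_grad()
```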
@@ -460,28 +458,6 @@
         model = self.trainer.lightning_module
         model.untoggle_optimizer(opt_idx)

-    @contextmanager
-    def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]:
-        """
-        automatic_optimization = True
-        Blocks ddp sync gradients behaviour on backwards pass.
-        This is useful for skipping sync when accumulating gradients, reducing communication overhead
-
-        automatic_optimization = False
-        do not block ddp gradient sync when using manual optimization
-        as gradients are needed within the training step
-
-        Returns:
-            context manager with sync behaviour off
-        """
-        if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (
-            self.trainer.lightning_module.automatic_optimization or should_block_sync
-        ):
-            with self.trainer.training_type_plugin.block_backward_sync():
-                yield None
-        else:
-            yield None
-
     def backward(
         self,
         loss: Tensor,

pytorch_lightning/loops/utilities.py

@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
-from typing import Any, Iterator, Mapping, Optional, Tuple
+from typing import Any, Generator, Iterator, Mapping, Optional, Tuple

 import torch

 import pytorch_lightning as pl
+from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -113,3 +114,23 @@ def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int)
     else:
         dataloader_iter = iter(data_fetcher)
     return dataloader_iter
+
+
+@contextmanager
+def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
+    """
+    Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`.
+    This is useful for example when accumulating gradients to reduce communication when it is not needed.
+
+    Args:
+        trainer: the trainer instance with a reference to a training type plugin
+        block: whether the context manager is enabled or not
+
+    Returns:
+        context manager with sync behaviour off
+    """
+    if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
+        with trainer.training_type_plugin.block_backward_sync():
+            yield None
+    else:
+        yield None
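
A hedged usage sketch of the new utility; the `trainer`, `loss`, and helper function below are placeholders for illustration and are not part of this commit:

```python
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior


def closure_backward(trainer, loss, is_accumulating: bool):
    # When the training type plugin is a ParallelPlugin and block=True, gradient
    # synchronization is suppressed for backward calls made inside the context;
    # otherwise the context manager is a no-op.
    with _block_parallel_sync_behavior(trainer, block=is_accumulating):
        loss.backward()
```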