move `block_ddp_sync_behaviour` to utilities (#9192)
commit 908b9eebc7
parent f79993a705
@@ -145,6 +145,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `rank_zero_warn` to `NotImplementedError` in the `{train, val, test, predict}_dataloader` hooks that `Lightning(Data)Module` uses ([#9161](https://github.com/PyTorchLightning/pytorch-lightning/pull/9161))

+- Moved `block_ddp_sync_behaviour` out of `TrainingBatchLoop` to loop utilities ([#9192](https://github.com/PyTorchLightning/pytorch-lightning/pull/9192))

### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
@@ -116,7 +116,10 @@ class LightningOptimizer:
        during the accumulation phase.
        Setting `sync_grad` to False will block this synchronization and improve performance.
        """
-        with self._trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
+        # local import here to avoid circular import
+        from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
+
+        with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
            self._toggle_model()
            yield
            self._untoggle_model()
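As context for the `sync_grad` docstring above: a hedged sketch of how a manual-optimization `LightningModule` might drive `toggle_model` to skip gradient synchronization while accumulating. The module body (layer sizes, the two-batch accumulation window) is illustrative only and not part of this commit.

```python
import torch
import pytorch_lightning as pl


class ManualAccumulation(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)  # illustrative model

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # Only synchronize gradients on every second batch; skip the DDP
        # all-reduce on the intermediate (accumulation) step.
        sync = (batch_idx + 1) % 2 == 0
        with opt.toggle_model(sync_grad=sync):
            loss = self.layer(batch).sum()
            self.manual_backward(loss)
        if sync:
            opt.step()
            opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```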
@@ -13,10 +13,9 @@
# limitations under the License.

from collections import OrderedDict
-from contextlib import contextmanager
from copy import copy
from functools import partial
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
@@ -28,11 +27,11 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.utilities import (
+    _block_parallel_sync_behavior,
    _check_training_step_output,
    _process_training_step_output,
    check_finite_loss,
)
-from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
@@ -186,9 +185,8 @@ class TrainingBatchLoop(Loop):
        # -------------------
        # calculate loss (train step + train step end)
        # -------------------
-        # automatic_optimization=True: perform ddp sync only when performing optimizer_step
-        # automatic_optimization=False: don't block synchronization here
-        with self.block_ddp_sync_behaviour():
+        # automatic_optimization: perform ddp sync only when performing optimizer_step
+        with _block_parallel_sync_behavior(self._trainer):
            closure()

        # ------------------------------
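For orientation only (not part of this diff): under DDP, "blocking sync" means skipping the gradient all-reduce that normally runs during backward. A rough plain-PyTorch sketch of the same idea, where `_maybe_no_sync` is a hypothetical stand-in for what `ParallelPlugin.block_backward_sync()` provides:

```python
import contextlib

import torch


def _maybe_no_sync(module: torch.nn.Module, block: bool):
    # When blocking is requested and the module is DDP-wrapped, enter DDP's
    # no_sync() so the backward pass skips the inter-process gradient all-reduce.
    if block and isinstance(module, torch.nn.parallel.DistributedDataParallel):
        return module.no_sync()
    # Otherwise synchronize as usual (no-op context).
    return contextlib.nullcontext()
```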
@@ -460,28 +458,6 @@
        model = self.trainer.lightning_module
        model.untoggle_optimizer(opt_idx)

-    @contextmanager
-    def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]:
-        """
-        automatic_optimization = True
-        Blocks ddp sync gradients behaviour on backwards pass.
-        This is useful for skipping sync when accumulating gradients, reducing communication overhead
-
-        automatic_optimization = False
-        do not block ddp gradient sync when using manual optimization
-        as gradients are needed within the training step
-
-        Returns:
-            context manager with sync behaviour off
-        """
-        if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (
-            self.trainer.lightning_module.automatic_optimization or should_block_sync
-        ):
-            with self.trainer.training_type_plugin.block_backward_sync():
-                yield None
-        else:
-            yield None
-
    def backward(
        self,
        loss: Tensor,
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Any, Iterator, Mapping, Optional, Tuple
+from contextlib import contextmanager
+from typing import Any, Generator, Iterator, Mapping, Optional, Tuple

import torch

import pytorch_lightning as pl
+from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -113,3 +114,23 @@ def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter
+
+
+@contextmanager
+def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
+    """
+    Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`.
+    This is useful, for example, when accumulating gradients to reduce communication when it is not needed.
+
+    Args:
+        trainer: the trainer instance with a reference to a training type plugin
+        block: whether the context manager is enabled or not
+
+    Returns:
+        context manager with sync behaviour off
+    """
+    if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
+        with trainer.training_type_plugin.block_backward_sync():
+            yield None
+    else:
+        yield None
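A hedged usage sketch of the helper added above. The driver function and its arguments (`trainer`, `model`, `batches`, the accumulation window) are assumptions for illustration, not code from this commit:

```python
from typing import Iterable

import torch
import pytorch_lightning as pl
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior


def accumulate_backward(trainer: pl.Trainer, model: torch.nn.Module, batches: Iterable, window: int = 4) -> None:
    """Illustrative only: backward over `batches`, syncing gradients once per `window`."""
    for batch_idx, batch in enumerate(batches):
        # Skip DDP's gradient all-reduce on intermediate batches of the window;
        # only the final batch synchronizes gradients across processes.
        accumulating = (batch_idx + 1) % window != 0
        with _block_parallel_sync_behavior(trainer, block=accumulating):
            loss = model(batch).sum()  # stand-in for a real loss
            loss.backward()
```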