move `block_ddp_sync_behaviour` to utilities (#9192)
parent f79993a705
commit 908b9eebc7

@@ -145,6 +145,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed `rank_zero_warn` to `NotImplementedError` in the `{train, val, test, predict}_dataloader` hooks that `Lightning(Data)Module` uses ([#9161](https://github.com/PyTorchLightning/pytorch-lightning/pull/9161))
 
+- Moved `block_ddp_sync_behaviour` out of `TrainingBatchLoop` to loop utilities ([#9192](https://github.com/PyTorchLightning/pytorch-lightning/pull/9192))
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`

@@ -116,7 +116,10 @@ class LightningOptimizer:
         during the accumulation phase.
         Setting `sync_grad` to False will block this synchronization and improve performance.
         """
-        with self._trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
+        # local import here to avoid circular import
+        from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
+
+        with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
             self._toggle_model()
             yield
             self._untoggle_model()

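For context, the `sync_grad` flag on `LightningOptimizer.toggle_model` is what user code reaches for when accumulating gradients under manual optimization. Below is a minimal sketch, not part of this commit, of a `training_step` that blocks DDP gradient synchronization on intermediate accumulation steps and only lets it run on the step that actually updates the optimizer; the model, loss, and accumulation count are placeholders.

    import torch
    import pytorch_lightning as pl

    class ManualAccumulation(pl.LightningModule):
        def __init__(self, accumulate: int = 4):
            super().__init__()
            self.automatic_optimization = False
            self.accumulate = accumulate
            self.layer = torch.nn.Linear(32, 2)  # placeholder model

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            x, y = batch
            is_last = (batch_idx + 1) % self.accumulate == 0
            # sync_grad=False blocks the DDP gradient all-reduce on accumulation steps
            with opt.toggle_model(sync_grad=is_last):
                loss = torch.nn.functional.mse_loss(self.layer(x), y) / self.accumulate
                self.manual_backward(loss)
            if is_last:
                opt.step()
                opt.zero_grad()
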
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 from collections import OrderedDict
-from contextlib import contextmanager
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch

@@ -28,11 +27,11 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.closure import Closure, ClosureResult
 from pytorch_lightning.loops.utilities import (
+    _block_parallel_sync_behavior,
     _check_training_step_output,
     _process_training_step_output,
     check_finite_loss,
 )
-from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.progress import OptimizationProgress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm

@@ -186,9 +185,8 @@ class TrainingBatchLoop(Loop):
         # -------------------
         # calculate loss (train step + train step end)
         # -------------------
-        # automatic_optimization=True: perform ddp sync only when performing optimizer_step
-        # automatic_optimization=False: don't block synchronization here
-        with self.block_ddp_sync_behaviour():
+        # automatic_optimization: perform ddp sync only when performing optimizer_step
+        with _block_parallel_sync_behavior(self._trainer):
             closure()
 
         # ------------------------------

@@ -460,28 +458,6 @@
         model = self.trainer.lightning_module
         model.untoggle_optimizer(opt_idx)
 
-    @contextmanager
-    def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]:
-        """
-        automatic_optimization = True
-        Blocks ddp sync gradients behaviour on backwards pass.
-        This is useful for skipping sync when accumulating gradients, reducing communication overhead
-
-        automatic_optimization = False
-        do not block ddp gradient sync when using manual optimization
-        as gradients are needed within the training step
-
-        Returns:
-            context manager with sync behaviour off
-        """
-        if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (
-            self.trainer.lightning_module.automatic_optimization or should_block_sync
-        ):
-            with self.trainer.training_type_plugin.block_backward_sync():
-                yield None
-        else:
-            yield None
-
     def backward(
         self,
         loss: Tensor,

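The behaviour being moved is ultimately about skipping DDP's gradient all-reduce: for the DDP-based plugins, `block_backward_sync()` defers to PyTorch's `DistributedDataParallel.no_sync()` context manager. As a point of reference, here is a standalone sketch of that underlying pattern in plain PyTorch, assuming an already-initialized process group and a DDP-wrapped model; the all-reduce is skipped on accumulation steps and happens only on the step that calls `optimizer.step()`.

    import contextlib

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def accumulate_and_step(model: DDP, optimizer: torch.optim.Optimizer, batches, accumulate: int = 4) -> None:
        for i, (x, y) in enumerate(batches):
            is_last = (i + 1) % accumulate == 0
            # Skip the gradient all-reduce on intermediate accumulation steps.
            sync_ctx = contextlib.nullcontext() if is_last else model.no_sync()
            with sync_ctx:
                loss = torch.nn.functional.mse_loss(model(x), y) / accumulate
                loss.backward()
            if is_last:
                optimizer.step()
                optimizer.zero_grad()
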
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterator, Mapping, Optional, Tuple
+from contextlib import contextmanager
+from typing import Any, Generator, Iterator, Mapping, Optional, Tuple
 
 import torch
 
 import pytorch_lightning as pl
+from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -113,3 +114,23 @@ def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int)
     else:
         dataloader_iter = iter(data_fetcher)
     return dataloader_iter
+
+
+@contextmanager
+def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
+    """
+    Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`.
+    This is useful, for example, when accumulating gradients to reduce communication when it is not needed.
+
+    Args:
+        trainer: the trainer instance with a reference to a training type plugin
+        block: whether the context manager is enabled or not
+
+    Returns:
+        context manager with sync behaviour off
+    """
+    if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
+        with trainer.training_type_plugin.block_backward_sync():
+            yield None
+    else:
+        yield None
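
With the helper now living in `pytorch_lightning.loops.utilities`, any loop code can wrap a closure the same way the training batch loop and `LightningOptimizer` do after this commit. A hypothetical call site, assuming a `trainer` that carries a training type plugin and a `closure` callable:

    from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior

    def run_closure(trainer, closure, accumulating: bool) -> None:
        # Block the gradient all-reduce while accumulating; otherwise the context manager is a no-op.
        with _block_parallel_sync_behavior(trainer, block=accumulating):
            closure()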