From bbd81dfd55eb07c7ee07ee6c0b95f207d3ade353 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Thu, 29 Oct 2020 18:31:37 +0100 Subject: [PATCH] Skips DDP parameter sync (#4301) * ddp no-sync * Update pytorch_lightning/trainer/training_loop.py Co-authored-by: ananthsub * Update training_loop.py * factor __enter__ and __exit__ out to separate context manager * delete _updated_model_last_step Co-authored-by: justusschock Co-authored-by: Teddy Koker Co-authored-by: ananthsub Co-authored-by: chaton Co-authored-by: Rohit Gupta --- pytorch_lightning/trainer/training_loop.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 934938c63f..d1dfb3eec3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import subprocess
+from contextlib import contextmanager
 from copy import copy, deepcopy
 
 import numpy as np
@@ -655,6 +655,7 @@ class TrainLoop:
         # checks if backward or backward + optimizer step (via closure)
         accumulation_done = self._accumulated_batches_reached()
         is_final_batch = self._num_training_batches_reached()
+        should_accumulate = not (accumulation_done or is_final_batch)
 
         # lightning module hook
         splits = self.tbptt_split_batch(batch)
@@ -675,13 +676,17 @@ class TrainLoop:
                     model = self.trainer.get_model()
                     model.toggle_optimizer(optimizer, opt_idx)
 
-                if not (accumulation_done or is_final_batch):
+                if should_accumulate:
                     # For gradient accumulation
 
                     # -------------------
                     # calculate loss (train step + train step end)
                     # -------------------
-                    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+
+                    # perform ddp sync only when performing optimizer_step
+                    with self.block_ddp_sync_behaviour():
+                        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+
                     batch_outputs = self._process_closure_result(
                         batch_callback_metrics=batch_callback_metrics,
                         batch_log_metrics=batch_log_metrics,
@@ -695,7 +700,6 @@ class TrainLoop:
                 # gradient update with accumulated gradients
 
                 else:
-
                     if self.automatic_optimization:
 
                         def train_step_and_backward_closure():
@@ -760,6 +764,22 @@ class TrainLoop:
         )
         return result
 
+    @contextmanager
+    def block_ddp_sync_behaviour(self):
+        """Run the ``with`` body without DDP gradient synchronization.
+
+        If ``self.trainer.model`` is a ``DistributedDataParallel`` instance, the
+        body executes inside ``model.no_sync()``; for any other model this
+        context manager does nothing.
+        """
+        if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
+            # no_sync() returns a context manager, which is not iterable, so
+            # ``yield from`` would raise TypeError; enter it with ``with``.
+            with self.trainer.model.no_sync():
+                yield
+        else:
+            yield
+
     def _process_closure_result(
         self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int
     ) -> list: