diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 934938c63f..d1dfb3eec3 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import subprocess
+from contextlib import contextmanager
 from copy import copy, deepcopy
 
 import numpy as np
@@ -655,6 +655,7 @@ class TrainLoop:
         # checks if backward or backward + optimizer step (via closure)
         accumulation_done = self._accumulated_batches_reached()
         is_final_batch = self._num_training_batches_reached()
+        should_accumulate = not (accumulation_done or is_final_batch)
 
         # lightning module hook
         splits = self.tbptt_split_batch(batch)
@@ -675,13 +676,17 @@ class TrainLoop:
                     model = self.trainer.get_model()
                     model.toggle_optimizer(optimizer, opt_idx)
 
-                if not (accumulation_done or is_final_batch):
+                if should_accumulate:
                     # For gradient accumulation
 
                     # -------------------
                     # calculate loss (train step + train step end)
                     # -------------------
-                    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+
+                    # perform ddp sync only when performing optimizer_step
+                    with self.block_ddp_sync_behaviour():
+                        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+
                     batch_outputs = self._process_closure_result(
                         batch_callback_metrics=batch_callback_metrics,
                         batch_log_metrics=batch_log_metrics,
@@ -695,7 +700,6 @@
 
                 # gradient update with accumulated gradients
                 else:
-
                     if self.automatic_optimization:
 
                         def train_step_and_backward_closure():
@@ -760,6 +764,15 @@
         )
         return result
 
+    @contextmanager
+    def block_ddp_sync_behaviour(self):
+        # skip the DDP gradient all-reduce while accumulating; sync resumes on exit
+        if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
+            with self.trainer.model.no_sync():
+                yield
+        else:
+            yield
+
     def _process_closure_result(
         self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int
     ) -> list:
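
For context, the pattern this patch wires into the accumulation branch looks roughly like the sketch below in plain PyTorch. The loop body, the loss computation, and names such as `dataloader`, `optimizer`, and `accumulate_grad_batches` are illustrative stand-ins, not part of the patch; the only API the sketch relies on is `DistributedDataParallel.no_sync()`, which suppresses the inter-process gradient all-reduce for backward passes run inside it.

```python
from contextlib import nullcontext

# Hypothetical training loop: `model` is assumed to be wrapped in
# torch.nn.parallel.DistributedDataParallel; `dataloader`, `optimizer`,
# and `accumulate_grad_batches` are illustrative stand-ins.
for i, batch in enumerate(dataloader):
    stepping = (i + 1) % accumulate_grad_batches == 0

    # Suppress the gradient all-reduce on intermediate micro-batches;
    # grads accumulated under no_sync() are reduced together during the
    # first backward pass that runs outside of it.
    sync_ctx = nullcontext() if stepping else model.no_sync()
    with sync_ctx:
        loss = model(batch).sum() / accumulate_grad_batches
        loss.backward()

    if stepping:
        optimizer.step()
        optimizer.zero_grad()
```

This is why the patch only wraps the `should_accumulate` branch: the final micro-batch of each window takes the normal path, so its backward pass performs the single all-reduce per optimizer step.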