From 68a1e522925f1cdfef33ff98f7ec5cc3f780c86c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 18 Jun 2020 20:15:02 -0400 Subject: [PATCH] added barrier (#2245) * added barrier * blank line * added barrier * added barrier * made fx public Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/trainer.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0f76c07229..35027cf3c1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -851,16 +851,18 @@ class Trainer( if self.is_function_implemented('on_fit_start'): model.on_fit_start() - self.setup('fit') - if self.is_function_implemented('setup'): - model.setup('fit') - # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 if self.can_prepare_data(): model.prepare_data() self._is_data_prepared = True + self.barrier('fit_prepare_data') + + self.setup('fit') + if self.is_function_implemented('setup'): + model.setup('fit') + # Run auto batch size scaling if self.auto_scale_batch_size: if isinstance(self.auto_scale_batch_size, bool): @@ -1150,6 +1152,8 @@ class Trainer( model_ref = self.model if model is None else model model_ref.setup('test') + self.barrier('test_setup') + if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0: raise MisconfigurationException( 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.') @@ -1251,6 +1255,14 @@ class Trainer( raise MisconfigurationException('You have defined `test_step()` but did not' ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.') + def barrier(self, name): + if self.use_ddp or self.use_ddp2: + torch_distrib.barrier() + + if self.on_tpu and XLA_AVAILABLE: + # wait for all processes to catch up + torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}') + class _PatchDataLoader(object): r"""