added barrier (#2245)

* added barrier

* blank line

* added barrier

* added barrier

* made fx public

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
William Falcon 2020-06-18 20:15:02 -04:00 committed by GitHub
parent 0b0c292cb9
commit 68a1e52292
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 16 additions and 4 deletions

View File

@ -851,16 +851,18 @@ class Trainer(
if self.is_function_implemented('on_fit_start'):
model.on_fit_start()
self.setup('fit')
if self.is_function_implemented('setup'):
model.setup('fit')
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
if self.can_prepare_data():
model.prepare_data()
self._is_data_prepared = True
self.barrier('fit_prepare_data')
self.setup('fit')
if self.is_function_implemented('setup'):
model.setup('fit')
# Run auto batch size scaling
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
@ -1150,6 +1152,8 @@ class Trainer(
model_ref = self.model if model is None else model
model_ref.setup('test')
self.barrier('test_setup')
if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
raise MisconfigurationException(
'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')
@ -1251,6 +1255,14 @@ class Trainer(
raise MisconfigurationException('You have defined `test_step()` but did not'
' implement `test_dataloader` nor passed in `.test(test_dataloader)`.')
def barrier(self, name):
if self.use_ddp or self.use_ddp2:
torch_distrib.barrier()
if self.on_tpu and XLA_AVAILABLE:
# wait for all processes to catch up
torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')
class _PatchDataLoader(object):
r"""