added barrier (#2245)
* added barrier * blank line * added barrier * added barrier * made fx public Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
0b0c292cb9
commit
68a1e52292
|
@ -851,16 +851,18 @@ class Trainer(
|
|||
if self.is_function_implemented('on_fit_start'):
|
||||
model.on_fit_start()
|
||||
|
||||
self.setup('fit')
|
||||
if self.is_function_implemented('setup'):
|
||||
model.setup('fit')
|
||||
|
||||
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
|
||||
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
|
||||
if self.can_prepare_data():
|
||||
model.prepare_data()
|
||||
self._is_data_prepared = True
|
||||
|
||||
self.barrier('fit_prepare_data')
|
||||
|
||||
self.setup('fit')
|
||||
if self.is_function_implemented('setup'):
|
||||
model.setup('fit')
|
||||
|
||||
# Run auto batch size scaling
|
||||
if self.auto_scale_batch_size:
|
||||
if isinstance(self.auto_scale_batch_size, bool):
|
||||
|
@ -1150,6 +1152,8 @@ class Trainer(
|
|||
model_ref = self.model if model is None else model
|
||||
model_ref.setup('test')
|
||||
|
||||
self.barrier('test_setup')
|
||||
|
||||
if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
|
||||
raise MisconfigurationException(
|
||||
'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')
|
||||
|
@ -1251,6 +1255,14 @@ class Trainer(
|
|||
raise MisconfigurationException('You have defined `test_step()` but did not'
|
||||
' implement `test_dataloader` nor passed in `.test(test_dataloader)`.')
|
||||
|
||||
def barrier(self, name):
|
||||
if self.use_ddp or self.use_ddp2:
|
||||
torch_distrib.barrier()
|
||||
|
||||
if self.on_tpu and XLA_AVAILABLE:
|
||||
# wait for all processes to catch up
|
||||
torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')
|
||||
|
||||
|
||||
class _PatchDataLoader(object):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue