diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 1b9a233184..e60a0d95cc 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -612,7 +612,7 @@ class Trainer(TrainerIO):
 
         # enable cluster checkpointing
         # also restores training state
-        if self.cluster is not None and self.proc_rank == 0:  # pragma: no cover
+        if self.cluster is not None:  # pragma: no cover
             self.enable_auto_hpc_walltime_manager()
 
         # ---------------------------
diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py
index 83fc1aa3af..39c5ae7b70 100644
--- a/pytorch_lightning/root_module/model_saving.py
+++ b/pytorch_lightning/root_module/model_saving.py
@@ -102,13 +102,16 @@ class TrainerIO(object):
             return
 
         # allow test tube to handle model check pointing automatically
-        self.cluster.set_checkpoint_save_function(
-            self.hpc_save,
-            kwargs={
-                'folderpath': self.checkpoint_callback.filepath,
-                'experiment': self.experiment
-            }
-        )
+        # only if proc 0 so we don't trigger world_size resubmits
+        if self.proc_rank == 0:
+            self.cluster.set_checkpoint_save_function(
+                self.hpc_save,
+                kwargs={
+                    'folderpath': self.checkpoint_callback.filepath,
+                    'experiment': self.experiment
+                }
+            )
+
        self.cluster.set_checkpoint_load_function(
            self.hpc_load,
            kwargs={
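
For context, the intent of the change can be sketched outside Lightning's classes: every process still registers the checkpoint *load* hook so it can restore its own state after a wall-time resubmit, but only rank 0 registers the *save* hook, otherwise each of the world_size processes would checkpoint and resubmit the job. The sketch below is a simplified illustration, not the project's actual API: ClusterManager, register_hpc_hooks, and the load-side kwargs are assumptions; only the rank-0 guard around the save hook mirrors the diff.

# sketch of the rank-0 guard pattern (hypothetical names, not Lightning internals)
class ClusterManager:
    """Stand-in for the SLURM/test-tube cluster object the diff assumes."""

    def __init__(self):
        self.save_fn = None
        self.load_fn = None

    def set_checkpoint_save_function(self, fn, kwargs):
        self.save_fn = (fn, kwargs)

    def set_checkpoint_load_function(self, fn, kwargs):
        self.load_fn = (fn, kwargs)


def register_hpc_hooks(cluster, proc_rank, hpc_save, hpc_load, filepath, experiment):
    # Save hook: rank 0 only, so a wall-time signal triggers a single
    # checkpoint + resubmit instead of one per process.
    if proc_rank == 0:
        cluster.set_checkpoint_save_function(
            hpc_save,
            kwargs={'folderpath': filepath, 'experiment': experiment},
        )

    # Load hook: every rank, so each process restores its own state.
    # (kwargs here are assumed; the diff is truncated at this point.)
    cluster.set_checkpoint_load_function(
        hpc_load,
        kwargs={'folderpath': filepath, 'experiment': experiment},
    )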