diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 50941e7b2f..98e5f1e774 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -363,15 +363,19 @@ class TrainerDDPMixin(ABC): :param model: :return: """ - # load weights saved in ddp - path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt') - loaded_model = original_model.__class__.load_from_checkpoint(path) - # copy loaded weights to old model - original_model.load_state_dict(loaded_model.state_dict()) + loaded_model = original_model - # remove ddp weights - os.remove(path) + if self.proc_rank == 0: + # load weights saved in ddp + path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt') + loaded_model = original_model.__class__.load_from_checkpoint(path) + + # copy loaded weights to old model + original_model.load_state_dict(loaded_model.state_dict()) + + # remove ddp weights + os.remove(path) return loaded_model