From 9754c5da55059dd89cf0a4fd582fe5df9449bbe5 Mon Sep 17 00:00:00 2001
From: areshytko
Date: Mon, 6 Apr 2020 17:17:16 +0300
Subject: [PATCH] load_spawn_weights only in proc rank 0 (#1385)

Co-authored-by: Alexander Reshytko
---
 .../trainer/distrib_data_parallel.py         | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 50941e7b2f..98e5f1e774 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -363,15 +363,19 @@ class TrainerDDPMixin(ABC):
         :param model:
         :return:
         """
-        # load weights saved in ddp
-        path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
-        loaded_model = original_model.__class__.load_from_checkpoint(path)
 
-        # copy loaded weights to old model
-        original_model.load_state_dict(loaded_model.state_dict())
+        loaded_model = original_model
 
-        # remove ddp weights
-        os.remove(path)
+        if self.proc_rank == 0:
+            # load weights saved in ddp
+            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+            loaded_model = original_model.__class__.load_from_checkpoint(path)
+
+            # copy loaded weights to old model
+            original_model.load_state_dict(loaded_model.state_dict())
+
+            # remove ddp weights
+            os.remove(path)
 
         return loaded_model
 
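
For readability, here is a sketch of how the load_spawn_weights body flows once the hunk above is applied. It is assembled from the + and context lines of the diff; the def line is assumed from the subject and the visible context (the real method also carries a docstring, and os is imported at module level in the file), so treat this as an illustration rather than the exact file contents.

    def load_spawn_weights(self, original_model):
        loaded_model = original_model

        if self.proc_rank == 0:
            # only rank 0 reads the temporary checkpoint written at the end of the DDP spawn
            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
            loaded_model = original_model.__class__.load_from_checkpoint(path)

            # copy the loaded weights into the existing model instance
            original_model.load_state_dict(loaded_model.state_dict())

            # clean up the temporary ddp checkpoint
            os.remove(path)

        return loaded_model

With this change, every rank other than 0 skips the checkpoint file I/O entirely and simply returns the original_model it already holds.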