load_spawn_weights only in proc rank 0 (#1385)

Co-authored-by: Alexander Reshytko <areshytko@Alexanders-MacBook-Pro.local>
This commit is contained in:
areshytko 2020-04-06 17:17:16 +03:00 committed by GitHub
parent 4ed3027309
commit 9754c5da55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 7 deletions

View File

@ -363,15 +363,19 @@ class TrainerDDPMixin(ABC):
:param model:
:return:
"""
# load weights saved in ddp
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
loaded_model = original_model.__class__.load_from_checkpoint(path)
# copy loaded weights to old model
original_model.load_state_dict(loaded_model.state_dict())
loaded_model = original_model
# remove ddp weights
os.remove(path)
if self.proc_rank == 0:
# load weights saved in ddp
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
loaded_model = original_model.__class__.load_from_checkpoint(path)
# copy loaded weights to old model
original_model.load_state_dict(loaded_model.state_dict())
# remove ddp weights
os.remove(path)
return loaded_model