load_spawn_weights only in proc rank 0 (#1385)
Co-authored-by: Alexander Reshytko <areshytko@Alexanders-MacBook-Pro.local>
This commit is contained in:
parent
4ed3027309
commit
9754c5da55
|
@ -363,15 +363,19 @@ class TrainerDDPMixin(ABC):
|
|||
:param model:
|
||||
:return:
|
||||
"""
|
||||
# load weights saved in ddp
|
||||
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
|
||||
loaded_model = original_model.__class__.load_from_checkpoint(path)
|
||||
|
||||
# copy loaded weights to old model
|
||||
original_model.load_state_dict(loaded_model.state_dict())
|
||||
loaded_model = original_model
|
||||
|
||||
# remove ddp weights
|
||||
os.remove(path)
|
||||
if self.proc_rank == 0:
|
||||
# load weights saved in ddp
|
||||
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
|
||||
loaded_model = original_model.__class__.load_from_checkpoint(path)
|
||||
|
||||
# copy loaded weights to old model
|
||||
original_model.load_state_dict(loaded_model.state_dict())
|
||||
|
||||
# remove ddp weights
|
||||
os.remove(path)
|
||||
|
||||
return loaded_model
|
||||
|
||||
|
|
Loading…
Reference in New Issue