load_spawn_weights only in proc rank 0 (#1385)
Co-authored-by: Alexander Reshytko <areshytko@Alexanders-MacBook-Pro.local>
This commit is contained in:
parent
4ed3027309
commit
9754c5da55
|
@ -363,15 +363,19 @@ class TrainerDDPMixin(ABC):
|
||||||
:param model:
|
:param model:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# load weights saved in ddp
|
|
||||||
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
|
|
||||||
loaded_model = original_model.__class__.load_from_checkpoint(path)
|
|
||||||
|
|
||||||
# copy loaded weights to old model
|
loaded_model = original_model
|
||||||
original_model.load_state_dict(loaded_model.state_dict())
|
|
||||||
|
|
||||||
# remove ddp weights
|
if self.proc_rank == 0:
|
||||||
os.remove(path)
|
# load weights saved in ddp
|
||||||
|
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
|
||||||
|
loaded_model = original_model.__class__.load_from_checkpoint(path)
|
||||||
|
|
||||||
|
# copy loaded weights to old model
|
||||||
|
original_model.load_state_dict(loaded_model.state_dict())
|
||||||
|
|
||||||
|
# remove ddp weights
|
||||||
|
os.remove(path)
|
||||||
|
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue