ref: remove weight loading hack for ddp_cpu (#3808)
This commit is contained in:
parent
afa43837a4
commit
a28528cc8b
|
@ -118,7 +118,8 @@ def cli_main():
|
|||
# ------------
|
||||
# testing
|
||||
# ------------
|
||||
trainer.test(test_dataloaders=test_loader)
|
||||
result = trainer.test(test_dataloaders=test_loader)
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -62,10 +62,9 @@ class DDPCPUSpawnBackend(Accelerator):
|
|||
# restore main state with best weights
|
||||
best_path = self.mp_queue.get()
|
||||
results = self.mp_queue.get()
|
||||
last_path = self.mp_queue.get()
|
||||
|
||||
# recover the weights of the processes trained in the children
|
||||
self.__recover_child_process_weights(model, best_path, last_path)
|
||||
self.__recover_child_process_weights(model, best_path)
|
||||
return results
|
||||
|
||||
def ddp_train(self, process_idx, mp_queue, model):
|
||||
|
@ -187,16 +186,10 @@ class DDPCPUSpawnBackend(Accelerator):
|
|||
device_ids = None
|
||||
return device_ids
|
||||
|
||||
def __recover_child_process_weights(self, model, best_path, last_path):
|
||||
def __recover_child_process_weights(self, model, best_path):
|
||||
# transfer back the best path to the trainer
|
||||
if self.trainer.checkpoint_callback:
|
||||
self.trainer.checkpoint_callback.best_model_path = best_path
|
||||
# todo, pass also best score
|
||||
|
||||
# load last weights
|
||||
if last_path is not None and not self.trainer.testing:
|
||||
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
|
||||
model.load_state_dict(ckpt)
|
||||
|
||||
self.trainer.model = model
|
||||
|
||||
|
@ -211,10 +204,3 @@ class DDPCPUSpawnBackend(Accelerator):
|
|||
# todo, pass complete checkpoint as state dictionary
|
||||
mp_queue.put(best_model_path)
|
||||
mp_queue.put(results)
|
||||
|
||||
# save the last weights
|
||||
last_path = None
|
||||
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
|
||||
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
|
||||
atomic_save(model.state_dict(), last_path)
|
||||
mp_queue.put(last_path)
|
||||
|
|
Loading…
Reference in New Issue