ref: remove weight loading hack for ddp_cpu (#3808)

This commit is contained in:
William Falcon 2020-10-02 19:28:50 -04:00 committed by GitHub
parent afa43837a4
commit a28528cc8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 17 deletions

View File

@ -118,7 +118,8 @@ def cli_main():
# ------------
# testing
# ------------
trainer.test(test_dataloaders=test_loader)
result = trainer.test(test_dataloaders=test_loader)
print(result)
if __name__ == '__main__':

View File

@ -62,10 +62,9 @@ class DDPCPUSpawnBackend(Accelerator):
# restore main state with best weights
best_path = self.mp_queue.get()
results = self.mp_queue.get()
last_path = self.mp_queue.get()
# recover the weights of the processes trained in the children
self.__recover_child_process_weights(model, best_path, last_path)
self.__recover_child_process_weights(model, best_path)
return results
def ddp_train(self, process_idx, mp_queue, model):
@ -187,16 +186,10 @@ class DDPCPUSpawnBackend(Accelerator):
device_ids = None
return device_ids
def __recover_child_process_weights(self, model, best_path, last_path):
def __recover_child_process_weights(self, model, best_path):
# transfer back the best path to the trainer
if self.trainer.checkpoint_callback:
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score
# load last weights
if last_path is not None and not self.trainer.testing:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)
self.trainer.model = model
@ -211,10 +204,3 @@ class DDPCPUSpawnBackend(Accelerator):
# todo, pass complete checkpoint as state dictionary
mp_queue.put(best_model_path)
mp_queue.put(results)
# save the last weights
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)