.fit() returns last not best weights in ddp_spawn (#2565)

* added base tests for tpu

* added base tests for tpu

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint

* enable none checkpoint
This commit is contained in:
William Falcon 2020-07-09 11:36:21 -04:00 committed by GitHub
parent e1bc208f66
commit 4bbcfa04a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 7 deletions

View File

@ -189,6 +189,7 @@ class TrainerDDPMixin(ABC):
num_nodes: int
node_rank: int
tpu_cores: int
testing: bool
@property
@abstractmethod
@ -555,15 +556,35 @@ class TrainerDDPMixin(ABC):
# continue training routine
results = self.run_pretrain_routine(model)
# persist info in ddp_spawn
self.__transfer_ddp_spawn_state_on_fit_end(model, q, results)
# clean up memory
torch.cuda.empty_cache()
if self.global_rank == 0 and self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
return results
def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
if not self.distributed_backend in ['ddp_spawn', 'ddp_cpu']:
return
# track the best model path
best_model_path = None
if self.checkpoint_callback is not None:
best_model_path = self.checkpoint_callback.best_model_path
if self.global_rank == 0 and q is not None:
q.put(self.checkpoint_callback.best_model_path)
rank_zero_warn('cleaning up ddp environment...')
q.put(best_model_path)
q.put(results)
if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
return results
# save the last weights
last_path = None
if not self.testing:
last_path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
torch.save(model.state_dict(), last_path)
q.put(last_path)
def save_spawn_weights(self, model):
"""
@ -574,6 +595,7 @@ class TrainerDDPMixin(ABC):
if self.is_global_zero:
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
self.save_checkpoint(path)
return path
def load_spawn_weights(self, original_model):
"""

View File

@ -35,7 +35,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
import warnings
# warnings to ignore
# warnings to ignore in trainer
warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, '
'please use torch.distributed.ReduceOp instead')
@ -1063,9 +1063,14 @@ class Trainer(
# restore main state with best weights
best_path = q.get()
results = q.get()
if best_path is not None and len(best_path) > 0:
self.checkpoint_callback.best_model_path = best_path
model.load_from_checkpoint(best_path)
last_path = q.get()
# transfer back the best path to the trainer
self.checkpoint_callback.best_model_path = best_path
# load last weights
if last_path is not None and not self.testing:
torch.load(last_path, map_location=lambda storage, loc: storage)
self.model = model
return results

View File

@ -23,9 +23,16 @@ def test_single_gpu_test(tmpdir):
results = trainer.test()
assert 'test_acc' in results
old_weights = model.c_d1.weight.clone().detach().cpu()
results = trainer.test(model)
assert 'test_acc' in results
# make sure weights didn't change
new_weights = model.c_d1.weight.clone().detach().cpu()
assert torch.all(torch.eq(old_weights, new_weights))
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_dp_test(tmpdir):
@ -45,9 +52,16 @@ def test_dp_test(tmpdir):
results = trainer.test()
assert 'test_acc' in results
old_weights = model.c_d1.weight.clone().detach().cpu()
results = trainer.test(model)
assert 'test_acc' in results
# make sure weights didn't change
new_weights = model.c_d1.weight.clone().detach().cpu()
assert torch.all(torch.eq(old_weights, new_weights))
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_ddp_spawn_test(tmpdir):
@ -67,5 +81,12 @@ def test_ddp_spawn_test(tmpdir):
results = trainer.test()
assert 'test_acc' in results
old_weights = model.c_d1.weight.clone().detach().cpu()
results = trainer.test(model)
assert 'test_acc' in results
# make sure weights didn't change
new_weights = model.c_d1.weight.clone().detach().cpu()
assert torch.all(torch.eq(old_weights, new_weights))