.fit() returns last not best weights in ddp_spawn (#2565)
* added base tests for tpu
* enable none checkpoint
parent e1bc208f66
commit 4bbcfa04a3
@@ -189,6 +189,7 @@ class TrainerDDPMixin(ABC):
     num_nodes: int
     node_rank: int
     tpu_cores: int
+    testing: bool
 
     @property
     @abstractmethod
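The added testing: bool follows the mixin's existing convention: TrainerDDPMixin only type-annotates the attributes it expects the concrete Trainer to provide, and the new worker-side code below reads self.testing, so the attribute is declared here as well. A minimal sketch of that mixin-attribute pattern, with illustrative names not taken from the repository:

    from abc import ABC

    class LoggingMixin(ABC):
        # annotated only; the composing class is expected to assign these
        verbose: bool
        node_rank: int

        def log_rank(self) -> None:
            if self.verbose:
                print(f"running on node {self.node_rank}")

    class Worker(LoggingMixin):
        def __init__(self, node_rank: int, verbose: bool = True) -> None:
            self.node_rank = node_rank
            self.verbose = verbose

    Worker(node_rank=0).log_rank()  # prints "running on node 0"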
@@ -555,15 +556,35 @@ class TrainerDDPMixin(ABC):
         # continue training routine
         results = self.run_pretrain_routine(model)
 
+        # persist info in ddp_spawn
+        self.__transfer_ddp_spawn_state_on_fit_end(model, q, results)
+
         # clean up memory
         torch.cuda.empty_cache()
 
+        if self.global_rank == 0 and self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
+            return results
+
+    def __transfer_ddp_spawn_state_on_fit_end(self, model, q, results):
+        if not self.distributed_backend in ['ddp_spawn', 'ddp_cpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.checkpoint_callback is not None:
+            best_model_path = self.checkpoint_callback.best_model_path
+
         if self.global_rank == 0 and q is not None:
-            q.put(self.checkpoint_callback.best_model_path)
             rank_zero_warn('cleaning up ddp environment...')
+            q.put(best_model_path)
             q.put(results)
 
-        if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
-            return results
+            # save the last weights
+            last_path = None
+            if not self.testing:
+                last_path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
+                torch.save(model.state_dict(), last_path)
+            q.put(last_path)
 
     def save_spawn_weights(self, model):
         """
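The worker-side change above makes the rank-0 spawned process hand three things back to the parent through the multiprocessing queue, in a fixed order: the best checkpoint path, the results, and the path of a freshly dumped copy of the last weights. A self-contained sketch of that hand-off pattern, assuming a toy model and a temp directory rather than Lightning's trainer internals:

    import os
    import tempfile

    import torch
    import torch.multiprocessing as mp


    def fit_worker(rank, q, tmpdir):
        # stand-in for the training routine that runs inside the spawned process
        model = torch.nn.Linear(4, 2)
        results = {"loss": 0.123}
        best_model_path = os.path.join(tmpdir, "best.ckpt")
        torch.save(model.state_dict(), best_model_path)

        if rank == 0:
            # push in the same order the parent will read them back
            q.put(best_model_path)
            q.put(results)
            last_path = os.path.join(tmpdir, "__temp_weight_ddp_end.ckpt")
            torch.save(model.state_dict(), last_path)
            q.put(last_path)


    if __name__ == "__main__":
        tmpdir = tempfile.mkdtemp()
        smp = mp.get_context("spawn")
        q = smp.SimpleQueue()
        mp.spawn(fit_worker, args=(q, tmpdir), nprocs=1)

        best_path, results, last_path = q.get(), q.get(), q.get()
        print(best_path, results, last_path)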
@@ -574,6 +595,7 @@ class TrainerDDPMixin(ABC):
         if self.is_global_zero:
             path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
             self.save_checkpoint(path)
+            return path
 
     def load_spawn_weights(self, original_model):
         """
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
 import warnings
 
-# warnings to ignore
+# warnings to ignore in trainer
 warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, '
                                            'please use torch.distributed.ReduceOp instead')
@@ -1063,9 +1063,14 @@ class Trainer(
         # restore main state with best weights
         best_path = q.get()
         results = q.get()
-        if best_path is not None and len(best_path) > 0:
-            self.checkpoint_callback.best_model_path = best_path
-            model.load_from_checkpoint(best_path)
+        last_path = q.get()
+
+        # transfer back the best path to the trainer
+        self.checkpoint_callback.best_model_path = best_path
+
+        # load last weights
+        if last_path is not None and not self.testing:
+            torch.load(last_path, map_location=lambda storage, loc: storage)
 
         self.model = model
         return results
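On the parent side the queue is drained in the same order the worker filled it, and the best path is copied back onto the checkpoint callback. Restoring dumped weights into an in-memory module is a two-step operation: torch.load reads the state dict off disk, load_state_dict copies it into the module. A minimal standalone sketch of that restore step (file name and module are illustrative, not Lightning's code):

    import torch

    model = torch.nn.Linear(4, 2)
    torch.save(model.state_dict(), "__temp_weight_ddp_end.ckpt")

    # later, possibly in another process: read the file, then copy tensors into the module
    restored = torch.nn.Linear(4, 2)
    state_dict = torch.load("__temp_weight_ddp_end.ckpt", map_location=lambda storage, loc: storage)
    restored.load_state_dict(state_dict)

    assert torch.equal(model.weight, restored.weight)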
@@ -23,9 +23,16 @@ def test_single_gpu_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results
 
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results
 
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_dp_test(tmpdir):
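The new assertions snapshot the first layer's weights before the second test() call and compare them afterwards; clone().detach().cpu() is what makes the snapshot independent of later in-place updates, of autograd history, and of whichever device the model ends up on. The same pattern in isolation, using a plain nn.Linear in place of the test model's c_d1 layer:

    import torch

    layer = torch.nn.Linear(8, 3)

    # independent CPU copy: clone() breaks aliasing, detach() drops autograd history
    old_weights = layer.weight.clone().detach().cpu()

    # an in-place update, standing in for anything a test run might do to the weights
    with torch.no_grad():
        layer.weight.add_(1.0)

    new_weights = layer.weight.clone().detach().cpu()
    assert not torch.all(torch.eq(old_weights, new_weights))  # the snapshot caught the change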
@@ -45,9 +52,16 @@ def test_dp_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results
 
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results
 
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_ddp_spawn_test(tmpdir):
@@ -67,5 +81,12 @@ def test_ddp_spawn_test(tmpdir):
     results = trainer.test()
     assert 'test_acc' in results
 
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
     results = trainer.test(model)
     assert 'test_acc' in results
 
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))