Fix checkpoint callback & Trainer.test(_) issue for TPUs (#6654)

* Fix checkpoint callback issue for TPUs
* update changelog
* add barrier
* apply code suggestions
* update trainer test
* remove spaces
* fix tpu tests
* Apply suggestions from code review
* add comment

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent b8ef52baa1
commit 2cbdc01256
CHANGELOG.md

@@ -188,6 +188,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))

+- Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
+
+- Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
+
 - Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))
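Both #6654 entries trace back to the same root cause: with `checkpoint_callback=False`, `trainer.checkpoint_callback` is `None`, and the TPU spawn plugin dereferenced it unconditionally. Below is a minimal, runnable sketch of that failure mode and of the guarded access the tpu_spawn.py hunk further down introduces; `DummyTrainer` is an illustrative stand-in, not part of the committed diff.

# Sketch only: DummyTrainer stands in for a Trainer built with
# checkpoint_callback=False; it is not part of the committed diff.
class DummyTrainer:
    checkpoint_callback = None  # no ModelCheckpoint attached

trainer = DummyTrainer()

# Pre-fix behaviour: unconditional attribute access raises AttributeError.
try:
    best_model_path = trainer.checkpoint_callback.best_model_path
except AttributeError as err:
    print(f"pre-fix: {err}")

# Post-fix behaviour: guard the access and fall back to None.
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
print(f"post-fix: best_model_path={best_model_path}")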
pytorch_lightning/plugins/training_type/tpu_spawn.py

@@ -17,7 +17,6 @@ import re
 from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
-import torch.distributed as torch_distrib
 import torch.multiprocessing as mp

 from pytorch_lightning.core.lightning import LightningModule
@@ -109,13 +108,15 @@ class TPUSpawnPlugin(DDPSpawnPlugin):

         # replace trainer save_checkpoint to use `xm.save`
         trainer.save_checkpoint = self.save_checkpoint
-        self.barrier()
+        self.barrier("pre-run-stage")

         results = trainer.run_stage()

         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

+        self.barrier("end-process")
+
     def __save_end_of_training_weights(self, model: LightningModule) -> None:
         # when training ends on these platforms dump weights to get out of the main process
         if on_colab_kaggle():
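`new_process` now brackets `run_stage` with named barriers. On XLA, a barrier is a rendezvous on a tag that every replica must reach, so descriptive tags such as "pre-run-stage" make a hung rendezvous much easier to attribute. A hedged sketch of the pattern follows; it needs a TPU runtime with torch_xla installed, so treat it as illustrative rather than something to run here.

# Illustrative only: requires torch_xla and live TPU replicas to run.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _replica_fn(index):
    # Every replica must hit the same tag, in the same order, or the
    # job hangs waiting for stragglers that never arrive.
    xm.rendezvous("pre-run-stage")
    # ... per-replica training/eval work would run here ...
    xm.rendezvous("end-process")

if __name__ == "__main__":
    xmp.spawn(_replica_fn, nprocs=8, start_method='fork')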
@@ -126,11 +127,11 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self._model.to(xm.xla_device())

     def barrier(self, name: Optional[str] = None) -> None:
-        if torch_distrib.is_initialized():
-            rendezvous(f"pl.Trainer.{name}")
+        rendezvous(name)

     def transfer_distrib_spawn_state_on_fit_end(self, results):
-        best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
+        checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
+        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

         if self.mp_queue is not None:
             rank_zero_warn("cleaning up ddp environment...")
pytorch_lightning/trainer/trainer.py

@@ -57,7 +57,7 @@ from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -983,7 +983,9 @@ class Trainer(
                 ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
             )

-        self.training_type_plugin.barrier()
+        # only one process running at this point for TPUs, as spawn isn't triggered yet
+        if not self._device_type == DeviceType.TPU:
+            self.training_type_plugin.barrier()

         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
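The `DeviceType.TPU` guard is the fix for the `trainer.test` freeze: on TPU, this code runs before `xmp.spawn` has forked the replicas, so a collective barrier would wait for peers that do not exist yet. A minimal, runnable sketch of the guard follows; the function and barrier stand-ins are illustrative, not Lightning APIs.

from enum import Enum

class DeviceType(Enum):  # simplified stand-in for Lightning's enum
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"

def load_ckpt_weights(device_type, barrier):
    # Pre-spawn on TPU there is exactly one process; a collective barrier
    # here would block forever waiting for the other replicas.
    if device_type != DeviceType.TPU:
        barrier()
    print("loading checkpoint weights on", device_type.value)

# On CPU/GPU the barrier runs as before; on TPU it is skipped.
load_ckpt_weights(DeviceType.GPU, barrier=lambda: print("barrier reached"))
load_ckpt_weights(DeviceType.TPU, barrier=lambda: None)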
tests/models/test_tpu.py

@@ -357,13 +357,14 @@ def test_tpu_reduce():
     xmp.spawn(test_reduce, nprocs=8, start_method='fork')


-@pytest.mark.parametrize("clip_val", [0, 10])
 @RunIf(tpu=True)
 @pl_multi_process_test
+@pytest.mark.parametrize("clip_val", [10])
 @mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
 def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
     """
     Ensure that clip gradients is only called if the value is greater than 0.
+    TODO: Fix (test fails with parametrize)
     """
     tutils.reset_seed()
     trainer_options = dict(
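The assertion logic this test relies on can be exercised without a TPU. Here is a runnable sketch of the same pattern, using a plain `unittest.mock.Mock` in place of the patched `xla_clip_grad_norm_`; the names are stand-ins, not the test's own.

from unittest import mock

def clip_gradients(clip_val, clip_fn):
    # mirrors the behaviour under test: clip only for positive values
    if clip_val > 0:
        clip_fn(clip_val)

for clip_val in (0, 10):
    clip_fn = mock.Mock()
    clip_gradients(clip_val, clip_fn)
    if clip_val > 0:
        clip_fn.assert_called()
    else:
        clip_fn.assert_not_called()
print("mock assertions hold for clip_val in (0, 10)")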
@@ -383,3 +384,15 @@ def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
         mock_clip_grad_norm.assert_called()
     else:
         mock_clip_grad_norm.assert_not_called()
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_if_test_works_with_checkpoint_false(tmpdir):
+    """Ensure that model trains properly when `checkpoint_callback` is set to False."""
+
+    # Train a model on TPU
+    model = BoringModel()
+    trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
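A natural companion to the new test would exercise the second fix in this PR, the `trainer.test` freeze, directly. The sketch below is not part of the committed diff; it assumes the same names already imported in tests/models/test_tpu.py (`BoringModel`, `Trainer`, `RunIf`, `pl_multi_process_test`) and a TPU runtime to execute on.

@RunIf(tpu=True)
@pl_multi_process_test
def test_trainer_test_with_checkpoint_false(tmpdir):
    """Sketch: `trainer.test` should no longer hang when no ModelCheckpoint is attached."""
    model = BoringModel()
    trainer = Trainer(
        max_epochs=1,
        tpu_cores=8,
        default_root_dir=tmpdir,
        fast_dev_run=True,
        checkpoint_callback=False,
    )
    trainer.fit(model)
    trainer.test(model)  # pre-fix this hit the pre-spawn barrier and froze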