From ffb8d4754710ce7473f7e14542af8f1594962b5d Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 9 Feb 2021 15:31:58 +0000
Subject: [PATCH] correct bugs

---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 5 ++++-
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 +
 pytorch_lightning/trainer/training_loop.py           | 2 ++
 tests/models/test_tpu.py                             | 4 ++--
 4 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 6f251eb369..943f2fc86c 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -171,6 +171,9 @@ class DDPSpawnPlugin(ParallelPlugin):
             return None
         return [self.root_device.index]
 
+    def on_save(self, checkpoint: dict) -> dict:
+        return checkpoint
+
     def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing callback through model -> trainer -> callback?
         best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
@@ -183,7 +186,7 @@ class DDPSpawnPlugin(ParallelPlugin):
             # TODO: is there a better way than accessing trainer through model -> trainer?
             if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                atomic_save(self.lightning_module.state_dict(), last_path)
+                atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)
 
             # todo, pass complete checkpoint as state dictionary
             self.mp_queue.put(best_model_path)
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 0f516e2b0b..1f1030d75a 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -95,6 +95,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         Recommended on XLA Guide:
         https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
         """
+        print("on_save")
         return move_data_to_device(checkpoint, torch.device("cpu"))
 
     def broadcast(self, obj: object, src: int = 0) -> object:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index c5f9f56a00..6fe57348b4 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -544,6 +544,8 @@ class TrainLoop:
 
             self.trainer.batch_idx = batch_idx
 
+            print("batch_idx")
+
             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END
             # ------------------------------------
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 98a02d730e..65fd88f45a 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -87,8 +87,8 @@ def test_model_tpu_cores_8(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=1,
         tpu_cores=8,
-        limit_train_batches=0.4,
-        limit_val_batches=0.4,
+        limit_train_batches=4,
+        limit_val_batches=4,
     )

     model = EvalModelTemplate()
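
For context, the hook this patch threads through DDPSpawnPlugin.transfer_distrib_spawn_state_on_fit_end
is a simple pre-save transform: the base spawn plugin's on_save returns the checkpoint unchanged, and
TPUSpawnPlugin overrides it to move XLA tensors to CPU before atomic_save writes them to disk. The
sketch below is illustrative only, not the library code: the names BaseSpawnPlugin, TPULikeSpawnPlugin,
and save_best_weights are made up for the example, and plain torch.save stands in for Lightning's
atomic_save.

    # Illustrative sketch of the on_save hook pattern (class/method names are hypothetical).
    import torch


    class BaseSpawnPlugin:
        def on_save(self, checkpoint: dict) -> dict:
            # Default behaviour, mirroring DDPSpawnPlugin.on_save: pass the checkpoint through.
            return checkpoint

        def save_best_weights(self, state_dict: dict, path: str) -> None:
            # Mirrors `atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)`.
            torch.save(self.on_save(state_dict), path)


    class TPULikeSpawnPlugin(BaseSpawnPlugin):
        def on_save(self, checkpoint: dict) -> dict:
            # The real TPUSpawnPlugin uses move_data_to_device(checkpoint, torch.device("cpu"))
            # so XLA tensors are saved as CPU tensors and can be loaded on any device.
            return {k: v.cpu() if torch.is_tensor(v) else v for k, v in checkpoint.items()}


    if __name__ == "__main__":
        model = torch.nn.Linear(2, 2)
        TPULikeSpawnPlugin().save_best_weights(model.state_dict(), "best.tmp_end.ckpt")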