diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9c5ffd75bd..f66a3f69b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -230,6 +230,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787))
 
 
+- Use a unique filename to save temp ckpt in tuner ([#9682](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682))
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 7e1181ebdb..e2cc0aab12 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -13,6 +13,7 @@
 # limitations under the License
 import logging
 import os
+import uuid
 from typing import Optional, Tuple
 
 import pytorch_lightning as pl
@@ -63,7 +64,7 @@ def scale_batch_size(
     __scale_batch_reset_params(trainer, model, steps_per_trial)
 
     # Save initial model, that is loaded after batch size is found
-    save_path = os.path.join(trainer.default_root_dir, "scale_batch_size_temp_model.ckpt")
+    save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
     trainer.save_checkpoint(str(save_path))
 
     if trainer.progress_bar_callback:
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index e44f316863..60e4c76418 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -14,6 +14,7 @@
 import importlib
 import logging
 import os
+import uuid
 from functools import wraps
 from typing import Callable, Optional, Sequence
 
@@ -208,7 +209,7 @@ def lr_find(
     if update_attr:
         lr_attr_name = _determine_lr_attr_name(trainer, model)
 
-    save_path = os.path.join(trainer.default_root_dir, "lr_find_temp_model.ckpt")
+    save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt")
 
     __lr_finder_dump_params(trainer, model)
 
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index d764afba5d..80fdf019b7 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -57,7 +57,7 @@ def test_model_reset_correctly(tmpdir):
             torch.eq(before_state_dict[key], after_state_dict[key])
         ), "Model was not reset correctly after learning rate finder"
 
-    assert not os.path.exists(tmpdir / "lr_find_temp_model.ckpt")
+    assert not any(f for f in os.listdir(tmpdir) if f.startswith("lr_find_temp_model"))
 
 
 def test_trainer_reset_correctly(tmpdir):
diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index a78ef828ca..c0d60e5bd8 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -91,6 +91,8 @@ def test_model_reset_correctly(tmpdir):
             torch.eq(before_state_dict[key], after_state_dict[key])
         ), "Model was not reset correctly after scaling batch size"
 
+    assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model"))
+
 
 def test_trainer_reset_correctly(tmpdir):
     """Check that all trainer parameters are reset correctly after scaling batch size."""