Use a unique filename to save temp ckpt in tuner (#9682)

* unique filename

* chlog

* update tests
This commit is contained in:
Rohit Gupta 2021-09-25 16:58:51 +05:30 committed by GitHub
parent 5395cebc51
commit a3def9d228
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 10 additions and 3 deletions

View File

@ -230,6 +230,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787))
- Use a unique filename to save temp ckpt in tuner ([#9682](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682))
### Deprecated
- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`

View File

@ -13,6 +13,7 @@
# limitations under the License
import logging
import os
import uuid
from typing import Optional, Tuple
import pytorch_lightning as pl
@ -63,7 +64,7 @@ def scale_batch_size(
__scale_batch_reset_params(trainer, model, steps_per_trial)
# Save initial model, that is loaded after batch size is found
save_path = os.path.join(trainer.default_root_dir, "scale_batch_size_temp_model.ckpt")
save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
trainer.save_checkpoint(str(save_path))
if trainer.progress_bar_callback:

View File

@ -14,6 +14,7 @@
import importlib
import logging
import os
import uuid
from functools import wraps
from typing import Callable, Optional, Sequence
@ -208,7 +209,7 @@ def lr_find(
if update_attr:
lr_attr_name = _determine_lr_attr_name(trainer, model)
save_path = os.path.join(trainer.default_root_dir, "lr_find_temp_model.ckpt")
save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt")
__lr_finder_dump_params(trainer, model)

View File

@ -57,7 +57,7 @@ def test_model_reset_correctly(tmpdir):
torch.eq(before_state_dict[key], after_state_dict[key])
), "Model was not reset correctly after learning rate finder"
assert not os.path.exists(tmpdir / "lr_find_temp_model.ckpt")
assert not any(f for f in os.listdir(tmpdir) if f.startswith("lr_find_temp_model"))
def test_trainer_reset_correctly(tmpdir):

View File

@ -91,6 +91,8 @@ def test_model_reset_correctly(tmpdir):
torch.eq(before_state_dict[key], after_state_dict[key])
), "Model was not reset correctly after scaling batch size"
assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model"))
def test_trainer_reset_correctly(tmpdir):
"""Check that all trainer parameters are reset correctly after scaling batch size."""