Use a unique filename to save temp ckpt in tuner (#9682)
* unique filename * chlog * update tests
This commit is contained in:
parent
5395cebc51
commit
a3def9d228
|
@ -230,6 +230,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787))
|
||||
|
||||
|
||||
- Use a unique filename to save temp ckpt in tuner ([#9682](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
@ -63,7 +64,7 @@ def scale_batch_size(
|
|||
__scale_batch_reset_params(trainer, model, steps_per_trial)
|
||||
|
||||
# Save initial model, that is loaded after batch size is found
|
||||
save_path = os.path.join(trainer.default_root_dir, "scale_batch_size_temp_model.ckpt")
|
||||
save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
|
||||
trainer.save_checkpoint(str(save_path))
|
||||
|
||||
if trainer.progress_bar_callback:
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional, Sequence
|
||||
|
||||
|
@ -208,7 +209,7 @@ def lr_find(
|
|||
if update_attr:
|
||||
lr_attr_name = _determine_lr_attr_name(trainer, model)
|
||||
|
||||
save_path = os.path.join(trainer.default_root_dir, "lr_find_temp_model.ckpt")
|
||||
save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt")
|
||||
|
||||
__lr_finder_dump_params(trainer, model)
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_model_reset_correctly(tmpdir):
|
|||
torch.eq(before_state_dict[key], after_state_dict[key])
|
||||
), "Model was not reset correctly after learning rate finder"
|
||||
|
||||
assert not os.path.exists(tmpdir / "lr_find_temp_model.ckpt")
|
||||
assert not any(f for f in os.listdir(tmpdir) if f.startswith("lr_find_temp_model"))
|
||||
|
||||
|
||||
def test_trainer_reset_correctly(tmpdir):
|
||||
|
|
|
@ -91,6 +91,8 @@ def test_model_reset_correctly(tmpdir):
|
|||
torch.eq(before_state_dict[key], after_state_dict[key])
|
||||
), "Model was not reset correctly after scaling batch size"
|
||||
|
||||
assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model"))
|
||||
|
||||
|
||||
def test_trainer_reset_correctly(tmpdir):
|
||||
"""Check that all trainer parameters are reset correctly after scaling batch size."""
|
||||
|
|
Loading…
Reference in New Issue