Make sure file and folder exists in Profiler (#10073)

Co-authored-by: tchaton <thomas@grid.ai>
This commit is contained in:
twsl 2021-10-26 13:13:31 +02:00 committed by GitHub
parent 84ce1d095c
commit 971281d27d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 0 deletions

View File

@ -653,6 +653,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
- Fixed creation of `dirpath` in `BaseProfiler` if it doesn't exist ([#10073](https://github.com/PyTorchLightning/pytorch-lightning/pull/10073))
## [1.4.9] - 2021-09-30
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))

View File

@ -120,6 +120,7 @@ class BaseProfiler(AbstractProfiler):
if self.filename:
filepath = os.path.join(self.dirpath, self._prepare_filename())
fs = get_filesystem(filepath)
fs.mkdirs(self.dirpath, exist_ok=True)
file = fs.open(filepath, "a")
self._output_file = file
self._write_stream = file.write

View File

@ -125,6 +125,43 @@ def test_simple_profiler_log_dir(tmpdir):
assert expected.join("fit-profiler.txt").exists()
def test_simple_profiler_with_nonexisting_log_dir(tmpdir):
"""Ensure the profiler dirpath defaults to `trainer.log_dir`and creates it when not present."""
nonexisting_tmpdir = tmpdir / "nonexisting"
profiler = SimpleProfiler(filename="profiler")
assert profiler._log_dir is None
model = BoringModel()
trainer = Trainer(
default_root_dir=nonexisting_tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, profiler=profiler
)
trainer.fit(model)
expected = nonexisting_tmpdir / "lightning_logs" / "version_0"
assert expected.exists()
assert trainer.log_dir == expected
assert profiler._log_dir == trainer.log_dir
assert expected.join("fit-profiler.txt").exists()
def test_simple_profiler_with_nonexisting_dirpath(tmpdir):
"""Ensure the profiler creates non-existing dirpath."""
nonexisting_tmpdir = tmpdir / "nonexisting"
profiler = SimpleProfiler(dirpath=nonexisting_tmpdir, filename="profiler")
assert profiler._log_dir is None
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, profiler=profiler
)
trainer.fit(model)
assert nonexisting_tmpdir.exists()
assert nonexisting_tmpdir.join("fit-profiler.txt").exists()
@RunIf(skip_windows=True)
def test_simple_profiler_distributed_files(tmpdir):
"""Ensure the proper files are saved in distributed."""