From 971281d27d0e6c05c9cdcfdb801ed26382620cb6 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Tue, 26 Oct 2021 13:13:31 +0200 Subject: [PATCH] Make sure file and folder exists in Profiler (#10073) Co-authored-by: tchaton --- CHANGELOG.md | 3 +++ pytorch_lightning/profiler/base.py | 1 + tests/profiler/test_profiler.py | 37 ++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 89e0290fc2..bdc2497c56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -653,6 +653,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764)) +- Fixed creation of `dirpath` in `BaseProfiler` if it doesn't exist ([#10073](https://github.com/PyTorchLightning/pytorch-lightning/pull/10073)) + + ## [1.4.9] - 2021-09-30 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704)) diff --git a/pytorch_lightning/profiler/base.py b/pytorch_lightning/profiler/base.py index f14082bfd8..f5df2b5382 100644 --- a/pytorch_lightning/profiler/base.py +++ b/pytorch_lightning/profiler/base.py @@ -120,6 +120,7 @@ class BaseProfiler(AbstractProfiler): if self.filename: filepath = os.path.join(self.dirpath, self._prepare_filename()) fs = get_filesystem(filepath) + fs.mkdirs(self.dirpath, exist_ok=True) file = fs.open(filepath, "a") self._output_file = file self._write_stream = file.write diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 36ef565d03..7369ab9a4a 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -125,6 +125,43 @@ def test_simple_profiler_log_dir(tmpdir): assert expected.join("fit-profiler.txt").exists() +def test_simple_profiler_with_nonexisting_log_dir(tmpdir): + """Ensure the profiler dirpath defaults to `trainer.log_dir`and creates it when not present.""" + nonexisting_tmpdir = tmpdir / "nonexisting" + + profiler = SimpleProfiler(filename="profiler") + assert profiler._log_dir is None + + model = BoringModel() + trainer = Trainer( + default_root_dir=nonexisting_tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, profiler=profiler + ) + trainer.fit(model) + + expected = nonexisting_tmpdir / "lightning_logs" / "version_0" + assert expected.exists() + assert trainer.log_dir == expected + assert profiler._log_dir == trainer.log_dir + assert expected.join("fit-profiler.txt").exists() + + +def test_simple_profiler_with_nonexisting_dirpath(tmpdir): + """Ensure the profiler creates non-existing dirpath.""" + nonexisting_tmpdir = tmpdir / "nonexisting" + + profiler = SimpleProfiler(dirpath=nonexisting_tmpdir, filename="profiler") + assert profiler._log_dir is None + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, profiler=profiler + ) + trainer.fit(model) + + assert nonexisting_tmpdir.exists() + assert nonexisting_tmpdir.join("fit-profiler.txt").exists() + + @RunIf(skip_windows=True) def test_simple_profiler_distributed_files(tmpdir): """Ensure the proper files are saved in distributed."""