From 11e289ad9f95f5fe23af147fa4edcc9794f9b9a7 Mon Sep 17 00:00:00 2001 From: mads-oestergaard <104391876+mads-oestergaard@users.noreply.github.com> Date: Mon, 23 May 2022 12:09:47 +0200 Subject: [PATCH] Update trainer profiler typehint to use `Profiler` instead of the deprecated `BaseProfiler` (#13084) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix trainer profiler typehint * Remove unused import of deprecated BaseProfiler * Update CHANGELOG.md Co-authored-by: Akihiro Nitta Co-authored-by: Carlos MocholĂ­ --- CHANGELOG.md | 3 +++ pytorch_lightning/trainer/trainer.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a35fdd2a6..eae507f32b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -226,6 +226,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue wrt unnecessary usage of habana mixed precision package for fp32 types ([#13028](https://github.com/PyTorchLightning/pytorch-lightning/pull/13028)) +- Fixed issue where the CLI could not pass a `Profiler` to the `Trainer` ([#13084](https://github.com/PyTorchLightning/pytorch-lightning/pull/13084)) + + - diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 71cb47b139..d9dc550f10 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -56,7 +56,6 @@ from pytorch_lightning.plugins import ( from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.profiler import ( AdvancedProfiler, - BaseProfiler, PassThroughProfiler, Profiler, PyTorchProfiler, @@ -171,7 +170,7 @@ class Trainer( weights_save_path: Optional[str] = None, # TODO: Remove in 1.8 num_sanity_val_steps: int = 2, resume_from_checkpoint: Optional[Union[Path, str]] = None, - profiler: Optional[Union[BaseProfiler, str]] = None, + profiler: Optional[Union[Profiler, str]] = None, benchmark: Optional[bool] = None, deterministic: Union[bool, _LITERAL_WARN] = False, reload_dataloaders_every_n_epochs: int = 0,