From 0a5e75e8d184c32042dac67c15cb4f194c82f6a2 Mon Sep 17 00:00:00 2001 From: Ray Schireman <41241487+rschireman@users.noreply.github.com> Date: Wed, 12 Oct 2022 08:22:01 -0400 Subject: [PATCH] Add `inference_mode` flag to Trainer (#15034) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rohit Gupta Co-authored-by: Carlos MocholĂ­ Co-authored-by: Jirka Borovec --- docs/source-pytorch/common/trainer.rst | 33 ++++++++++++++++ src/pytorch_lightning/CHANGELOG.md | 5 ++- src/pytorch_lightning/trainer/trainer.py | 17 ++++++-- .../trainer/flags/test_inference_mode.py | 39 +++++++++++++++++++ 4 files changed, 89 insertions(+), 5 deletions(-) create mode 100644 tests/tests_pytorch/trainer/flags/test_inference_mode.py diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 7afabc07d5..e25525b949 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -1525,6 +1525,39 @@ Whether to enable or disable the model summarization. Defaults to True. trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)]) + +inference_mode +^^^^^^^^^^^^^^ + +Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation +(``validate``/``test``/``predict``) + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(inference_mode=True) + + # Use `torch.no_grad` instead + trainer = Trainer(inference_mode=False) + + +With :func:`torch.inference_mode` disabled, you can enable the grad of your model layers if required. + +.. 
code-block:: python + + class LitModel(LightningModule): + def validation_step(self, batch, batch_idx): + preds = self.layer1(batch) + with torch.enable_grad(): + grad_preds = preds.requires_grad_() + preds2 = self.layer2(grad_preds) + + + model = LitModel() + trainer = Trainer(inference_mode=False) + trainer.validate(model) + + ----- Trainer class API diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index ea8a7aa6f6..0d73c9fbe8 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -70,7 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709)) -- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998) +- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998)) + + +- Added `inference_mode` flag to Trainer to let users enable/disable inference mode during evaluation ([#15034](https://github.com/Lightning-AI/lightning/pull/15034)) ### Changed diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 190a75ec2b..a385c96cdf 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -162,6 +162,7 @@ class Trainer: amp_level: Optional[str] = None, move_metrics_to_cpu: bool = False, multiple_trainloader_mode: str = "max_size_cycle", + inference_mode: bool = True, ) -> None: r""" Customize every aspect of training via flags. @@ -388,6 +389,9 @@ class Trainer: and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets reload when reaching the minimum length of datasets. Default: ``"max_size_cycle"``. 
+ + inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during + evaluation (``validate``/``test``/``predict``). """ super().__init__() Trainer._log_api_event("init") @@ -487,6 +491,8 @@ class Trainer: ) self.track_grad_norm: float = float(track_grad_norm) + self._inference_mode: bool = inference_mode + self._detect_anomaly: bool = detect_anomaly self._setup_on_init() @@ -1159,7 +1165,9 @@ class Trainer: # reset trainer on this loop and all child loops in case user connected a custom loop self._evaluation_loop.trainer = self - with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(self.accelerator): + with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context( + self.accelerator, self._inference_mode + ): eval_loop_results = self._evaluation_loop.run() # remove the tensors from the eval results @@ -1175,7 +1183,7 @@ class Trainer: self.reset_predict_dataloader(self.lightning_module) # reset trainer on this loop and all child loops in case user connected a custom loop self.predict_loop.trainer = self - with _evaluation_context(self.accelerator): + with _evaluation_context(self.accelerator, self._inference_mode): return self.predict_loop.run() def _run_sanity_check(self) -> None: @@ -2210,12 +2218,13 @@ class Trainer: @contextmanager -def _evaluation_context(accelerator: Accelerator) -> Generator: +def _evaluation_context(accelerator: Accelerator, inference_mode: bool = True) -> Generator: # inference mode is not supported with gloo backend (#9431), # and HPU & TPU accelerators. 
context_manager_class = ( torch.inference_mode - if not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo") + if inference_mode + and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo") and not isinstance(accelerator, HPUAccelerator) and not isinstance(accelerator, TPUAccelerator) else torch.no_grad diff --git a/tests/tests_pytorch/trainer/flags/test_inference_mode.py b/tests/tests_pytorch/trainer/flags/test_inference_mode.py new file mode 100644 index 0000000000..3ac65348c3 --- /dev/null +++ b/tests/tests_pytorch/trainer/flags/test_inference_mode.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.demos.boring_classes import BoringModel
+
+
+def test_eval_inference_mode():
+    """Test that the ``inference_mode`` flag selects ``torch.inference_mode`` vs ``torch.no_grad`` for evaluation."""
+
+    class BoringModelNoGrad(BoringModel):
+        def on_test_epoch_start(self) -> None:
+            assert not torch.is_grad_enabled()
+            assert not torch.is_inference_mode_enabled()
+            return super().on_test_epoch_start()
+
+    class BoringModelForInferenceMode(BoringModel):
+        def on_test_epoch_start(self) -> None:
+            assert not torch.is_grad_enabled()
+            assert torch.is_inference_mode_enabled()
+            return super().on_test_epoch_start()
+
+    trainer = Trainer(logger=False, inference_mode=False, fast_dev_run=True)
+    trainer.test(BoringModelNoGrad())
+    trainer = Trainer(logger=False, inference_mode=True, fast_dev_run=True)
+    trainer.test(BoringModelForInferenceMode())