Add `inference_mode` flag to Trainer (#15034)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Ray Schireman 2022-10-12 08:22:01 -04:00 committed by GitHub
parent ad1e06f2d4
commit 0a5e75e8d1
4 changed files with 89 additions and 5 deletions


@@ -1525,6 +1525,39 @@ Whether to enable or disable the model summarization. Defaults to True.
trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
inference_mode
^^^^^^^^^^^^^^
Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during evaluation
(``validate``/``test``/``predict``).

.. testcode::

    # default used by the Trainer
    trainer = Trainer(inference_mode=True)

    # use `torch.no_grad` instead
    trainer = Trainer(inference_mode=False)
With :func:`torch.inference_mode` disabled, you can re-enable gradient computation inside your model where required:

.. code-block:: python

    import torch

    from pytorch_lightning import LightningModule, Trainer


    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            preds = self.layer1(batch)
            # re-enable autograd locally; this is only possible under
            # ``torch.no_grad``, i.e. with ``inference_mode=False``
            with torch.enable_grad():
                grad_preds = preds.requires_grad_()
                preds2 = self.layer2(grad_preds)


    model = LitModel()
    trainer = Trainer(inference_mode=False)
    trainer.validate(model)
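
As a rough standalone illustration of the difference (plain PyTorch, not part of this change): tensors created under :func:`torch.inference_mode` can never participate in autograd afterwards, whereas :func:`torch.no_grad` can be locally overridden with :func:`torch.enable_grad`, which is what the example above relies on.

.. code-block:: python

    import torch

    x = torch.randn(3)

    with torch.inference_mode():
        y = x * 2  # ``y`` is an inference tensor and can never require grad

    with torch.no_grad():
        with torch.enable_grad():  # re-enabling grad locally is allowed under ``no_grad``
            z = (x.clone().requires_grad_() * 2).sum()
            z.backward()  # works: the graph was recorded inside ``enable_grad``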
-----

Trainer class API
-----------------

View File

@@ -70,7 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))
-- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998)
+- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998))
+- Added `inference_mode` flag to Trainer to let users enable/disable inference mode during evaluation ([#15034](https://github.com/Lightning-AI/lightning/pull/15034))
### Changed


@@ -162,6 +162,7 @@ class Trainer:
        amp_level: Optional[str] = None,
        move_metrics_to_cpu: bool = False,
        multiple_trainloader_mode: str = "max_size_cycle",
+        inference_mode: bool = True,
    ) -> None:
r"""
Customize every aspect of training via flags.
@@ -388,6 +389,9 @@ class Trainer:
                and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
                reload when reaching the minimum length of datasets.
                Default: ``"max_size_cycle"``.
+
+            inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
+                evaluation (``validate``/``test``/``predict``).
        """
super().__init__()
Trainer._log_api_event("init")
@ -487,6 +491,8 @@ class Trainer:
)
self.track_grad_norm: float = float(track_grad_norm)
self._inference_mode: bool = inference_mode
self._detect_anomaly: bool = detect_anomaly
self._setup_on_init()
@@ -1159,7 +1165,9 @@ class Trainer:
        # reset trainer on this loop and all child loops in case user connected a custom loop
        self._evaluation_loop.trainer = self

-        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(self.accelerator):
+        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
+            self.accelerator, self._inference_mode
+        ):
            eval_loop_results = self._evaluation_loop.run()

        # remove the tensors from the eval results
@@ -1175,7 +1183,7 @@ class Trainer:
        self.reset_predict_dataloader(self.lightning_module)
        # reset trainer on this loop and all child loops in case user connected a custom loop
        self.predict_loop.trainer = self

-        with _evaluation_context(self.accelerator):
+        with _evaluation_context(self.accelerator, self._inference_mode):
            return self.predict_loop.run()

    def _run_sanity_check(self) -> None:
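
(Aside, not part of the diff: with the flag threaded through the predict path as well, prediction can likewise fall back to ``torch.no_grad`` — a quick sketch using Lightning's ``BoringModel``.)

    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel

    trainer = Trainer(inference_mode=False, fast_dev_run=True)
    predictions = trainer.predict(BoringModel())  # evaluation context uses torch.no_grad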
@@ -2210,12 +2218,13 @@ class Trainer:
@contextmanager
-def _evaluation_context(accelerator: Accelerator) -> Generator:
+def _evaluation_context(accelerator: Accelerator, inference_mode: bool = True) -> Generator:
    # inference mode is not supported with gloo backend (#9431),
    # and HPU & TPU accelerators.
    context_manager_class = (
        torch.inference_mode
-        if not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
+        if inference_mode
+        and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
        and not isinstance(accelerator, HPUAccelerator)
        and not isinstance(accelerator, TPUAccelerator)
        else torch.no_grad
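
(Aside, not part of the diff: a trimmed-down sketch of what this helper now does, with the distributed-backend and accelerator checks dropped for brevity. ``_evaluation_context_sketch`` is an illustrative name.)

    import torch
    from contextlib import contextmanager
    from typing import Generator


    @contextmanager
    def _evaluation_context_sketch(inference_mode: bool = True) -> Generator:
        # pick the context manager the same way the helper above does,
        # minus the gloo/HPU/TPU special cases
        context_manager_class = torch.inference_mode if inference_mode else torch.no_grad
        with context_manager_class():
            yield


    with _evaluation_context_sketch(inference_mode=True):
        assert torch.is_inference_mode_enabled()

    with _evaluation_context_sketch(inference_mode=False):
        assert not torch.is_grad_enabled() and not torch.is_inference_mode_enabled()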


@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

def test_eval_inference_mode():
    """Test that the ``inference_mode`` flag switches evaluation between ``torch.inference_mode`` and
    ``torch.no_grad``."""

    class BoringModelNoGrad(BoringModel):
        def on_test_epoch_start(self) -> None:
            # gradients are disabled either way, but inference mode is off here
            assert not torch.is_grad_enabled()
            assert not torch.is_inference_mode_enabled()
            return super().on_test_epoch_start()

    class BoringModelForInferenceMode(BoringModel):
        def on_test_epoch_start(self) -> None:
            # with the default, evaluation runs in inference mode
            assert not torch.is_grad_enabled()
            assert torch.is_inference_mode_enabled()
            return super().on_test_epoch_start()

    trainer = Trainer(logger=False, inference_mode=False, fast_dev_run=True)
    trainer.test(BoringModelNoGrad())

    trainer = Trainer(logger=False, inference_mode=True, fast_dev_run=True)
    trainer.test(BoringModelForInferenceMode())
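
To round out the picture, a hedged end-to-end sketch connecting the new flag to the docs example above (not part of this PR; ``GradInValidation`` is an illustrative name):

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel


    class GradInValidation(BoringModel):
        def validation_step(self, batch, batch_idx):
            # only possible because the Trainer below evaluates under
            # torch.no_grad (inference_mode=False) rather than inference mode
            with torch.enable_grad():
                out = self.layer(batch.clone().requires_grad_())
                assert out.requires_grad
            return super().validation_step(batch, batch_idx)


    trainer = Trainer(logger=False, inference_mode=False, fast_dev_run=True)
    trainer.validate(GradInValidation())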