Add `inference_mode` flag to Trainer (#15034)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
ad1e06f2d4
commit
0a5e75e8d1
|
@ -1525,6 +1525,39 @@ Whether to enable or disable the model summarization. Defaults to True.
|
|||
|
||||
trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
|
||||
|
||||
|
||||
inference_mode
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation
|
||||
(``validate``/``test``/``predict``)
|
||||
|
||||
.. testcode::
|
||||
|
||||
# default used by the Trainer
|
||||
trainer = Trainer(inference_mode=True)
|
||||
|
||||
# Use `torch.no_grad` instead
|
||||
trainer = Trainer(inference_mode=False)
|
||||
|
||||
|
||||
With :func:`torch.inference_mode` disabled, you can enable the grad of your model layers if required.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class LitModel(LightningModule):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
preds = self.layer1(batch)
|
||||
with torch.enable_grad():
|
||||
grad_preds = preds.requires_grad_()
|
||||
preds2 = self.layer2(grad_preds)
|
||||
|
||||
|
||||
model = LitModel()
|
||||
trainer = Trainer(inference_mode=False)
|
||||
trainer.validate(model)
|
||||
|
||||
|
||||
-----
|
||||
|
||||
Trainer class API
|
||||
|
|
|
@ -70,7 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))
|
||||
|
||||
|
||||
- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998)
|
||||
- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998))
|
||||
|
||||
|
||||
- Added `inference_mode` flag to Trainer to let users enable/disable inference mode during evaluation ([#15034](https://github.com/Lightning-AI/lightning/pull/15034))
|
||||
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -162,6 +162,7 @@ class Trainer:
|
|||
amp_level: Optional[str] = None,
|
||||
move_metrics_to_cpu: bool = False,
|
||||
multiple_trainloader_mode: str = "max_size_cycle",
|
||||
inference_mode: bool = True,
|
||||
) -> None:
|
||||
r"""
|
||||
Customize every aspect of training via flags.
|
||||
|
@ -388,6 +389,9 @@ class Trainer:
|
|||
and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
|
||||
reload when reaching the minimum length of datasets.
|
||||
Default: ``"max_size_cycle"``.
|
||||
|
||||
inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
|
||||
evaluation (``validate``/``test``/``predict``).
|
||||
"""
|
||||
super().__init__()
|
||||
Trainer._log_api_event("init")
|
||||
|
@ -487,6 +491,8 @@ class Trainer:
|
|||
)
|
||||
self.track_grad_norm: float = float(track_grad_norm)
|
||||
|
||||
self._inference_mode: bool = inference_mode
|
||||
|
||||
self._detect_anomaly: bool = detect_anomaly
|
||||
self._setup_on_init()
|
||||
|
||||
|
@ -1159,7 +1165,9 @@ class Trainer:
|
|||
# reset trainer on this loop and all child loops in case user connected a custom loop
|
||||
self._evaluation_loop.trainer = self
|
||||
|
||||
with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(self.accelerator):
|
||||
with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
|
||||
self.accelerator, self._inference_mode
|
||||
):
|
||||
eval_loop_results = self._evaluation_loop.run()
|
||||
|
||||
# remove the tensors from the eval results
|
||||
|
@ -1175,7 +1183,7 @@ class Trainer:
|
|||
self.reset_predict_dataloader(self.lightning_module)
|
||||
# reset trainer on this loop and all child loops in case user connected a custom loop
|
||||
self.predict_loop.trainer = self
|
||||
with _evaluation_context(self.accelerator):
|
||||
with _evaluation_context(self.accelerator, self._inference_mode):
|
||||
return self.predict_loop.run()
|
||||
|
||||
def _run_sanity_check(self) -> None:
|
||||
|
@ -2210,12 +2218,13 @@ class Trainer:
|
|||
|
||||
|
||||
@contextmanager
|
||||
def _evaluation_context(accelerator: Accelerator) -> Generator:
|
||||
def _evaluation_context(accelerator: Accelerator, inference_mode: bool = True) -> Generator:
|
||||
# inference mode is not supported with gloo backend (#9431),
|
||||
# and HPU & TPU accelerators.
|
||||
context_manager_class = (
|
||||
torch.inference_mode
|
||||
if not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
|
||||
if inference_mode
|
||||
and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
|
||||
and not isinstance(accelerator, HPUAccelerator)
|
||||
and not isinstance(accelerator, TPUAccelerator)
|
||||
else torch.no_grad
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||
|
||||
|
||||
def test_eval_inference_mode():
    """Test that the ``inference_mode`` Trainer flag selects the evaluation grad context.

    With ``inference_mode=False`` evaluation runs under ``torch.no_grad``; with
    ``inference_mode=True`` it runs under ``torch.inference_mode``. Each model
    asserts the active context from ``on_test_epoch_start``, i.e. inside the
    evaluation loop where the context manager is applied.
    """

    class BoringModelNoGrad(BoringModel):
        def on_test_epoch_start(self) -> None:
            # ``torch.no_grad`` disables grad tracking but is NOT inference mode
            assert not torch.is_grad_enabled()
            assert not torch.is_inference_mode_enabled()
            return super().on_test_epoch_start()

    class BoringModelForInferenceMode(BoringModel):
        def on_test_epoch_start(self) -> None:
            # ``torch.inference_mode`` implies grad tracking is disabled as well
            assert not torch.is_grad_enabled()
            assert torch.is_inference_mode_enabled()
            return super().on_test_epoch_start()

    # fast_dev_run keeps the run to a single batch; logger=False avoids log artifacts
    trainer = Trainer(logger=False, inference_mode=False, fast_dev_run=True)
    trainer.test(BoringModelNoGrad())
    trainer = Trainer(logger=False, inference_mode=True, fast_dev_run=True)
    trainer.test(BoringModelForInferenceMode())
|
Loading…
Reference in New Issue