diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index cf7f38707c..1f52128f06 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -15,8 +15,6 @@ import math from collections import OrderedDict from typing import Any, Dict, Optional, Union -import torch - import lightning.pytorch as pl from lightning.pytorch import loops # import as loops to avoid circular imports from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher @@ -284,8 +282,7 @@ class _TrainingEpochLoop(loops._Loop): # reload dataloaders self.val_loop._reload_evaluation_dataloaders() - with torch.no_grad(): - self.val_loop.run() + self.val_loop.run() def _accumulated_batches_reached(self) -> bool: """Determine if accumulation will be finished by the end of the current batch.""" diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 2926982ecc..35beb7c57d 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -956,8 +956,7 @@ class Trainer: ] # run eval step - with torch.no_grad(): - val_loop.run() + val_loop.run() call._call_callback_hooks(self, "on_sanity_check_end")