diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 6cc2ac55c5..5613f86342 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -319,6 +319,12 @@ class TrainerTrainLoopMixin(ABC): # get model model = self.get_model() + # enable train mode + model.train() + + # enable gradients + torch.set_grad_enabled(True) + # load data # if reload_dataloaders_every_epoch, this is moved to the epoch loop if not self.reload_dataloaders_every_epoch: