Enable gradients at train start (#2200)

* Enable gradients at train start

* Update training_loop.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Nilansh Rajput 2020-06-17 20:22:58 +05:30 committed by GitHub
parent 2330d32531
commit 25c7465591
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 0 deletions

View File

@ -319,6 +319,12 @@ class TrainerTrainLoopMixin(ABC):
# get model
model = self.get_model()
# enable train mode
model.train()
# enable gradients
torch.set_grad_enabled(True)
# load data
# if reload_dataloaders_every_epoch, this is moved to the epoch loop
if not self.reload_dataloaders_every_epoch: