Enable gradients at train start (#2200)
* Enable gradients at train start * Update training_loop.py Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
2330d32531
commit
25c7465591
|
@ -319,6 +319,12 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# get model
|
||||
model = self.get_model()
|
||||
|
||||
# enable train mode
|
||||
model.train()
|
||||
|
||||
# enable gradients
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
# load data
|
||||
# if reload_dataloaders_every_epoch, this is moved to the epoch loop
|
||||
if not self.reload_dataloaders_every_epoch:
|
||||
|
|
Loading…
Reference in New Issue