lightning/pytorch_lightning/trainer/training_tricks_mixin.py

29 lines
1.1 KiB
Python
Raw Normal View History

import torch
import logging
from pytorch_lightning.callbacks import GradientAccumulationScheduler
class TrainerTrainingTricksMixin(object):
    """Mixin providing gradient clipping, NaN-gradient reporting and
    gradient-accumulation setup.

    The host class is expected to supply ``get_model()`` and, for
    ``clip_gradients``, a numeric ``gradient_clip_val`` attribute; neither is
    defined here.
    """

    def clip_gradients(self) -> None:
        """Clip the model's gradient norm in place to ``gradient_clip_val``.

        Nothing happens when ``gradient_clip_val`` is zero or negative.
        """
        if self.gradient_clip_val > 0:
            model = self.get_model()
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)

    def print_nan_gradients(self) -> None:
        """Log (at INFO level, on the root logger) every model parameter whose
        gradient contains at least one NaN, together with that gradient.

        Parameters without a gradient (``grad is None``) are skipped.
        """
        model = self.get_model()
        for param in model.parameters():
            if (param.grad is not None) and torch.isnan(param.grad.float()).any():
                # The first argument of logging.info is a %-format string.
                # Passing the tensor itself there makes record formatting fail
                # ("not all arguments converted"), so give explicit
                # placeholders and let logging format both values lazily.
                logging.info("%s %s", param, param.grad)

    def configure_accumulated_gradients(self, accumulate_grad_batches) -> None:
        """Create ``self.accumulation_scheduler`` from the user setting.

        Args:
            accumulate_grad_batches: either a dict, handed to
                ``GradientAccumulationScheduler`` unchanged, or an int, which
                is wrapped as the schedule ``{1: accumulate_grad_batches}``
                (presumably "use this factor from epoch 1 on" -- confirm
                against ``GradientAccumulationScheduler``).

        Raises:
            TypeError: if ``accumulate_grad_batches`` is neither an int nor a
                dict.
        """
        # Reset before validation so the attribute exists even when the
        # argument is rejected below.
        self.accumulate_grad_batches = None

        if isinstance(accumulate_grad_batches, dict):
            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
        elif isinstance(accumulate_grad_batches, int):
            schedule = {1: accumulate_grad_batches}
            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
        else:
            raise TypeError("Gradient accumulation supports only int and dict types")