# lightning/pytorch_lightning/root_module/hooks.py
#
# 73 lines
# 1.6 KiB
# Python

import torch
# apex is an optional dependency: remember whether the import worked so that
# code needing ``amp`` can fail with a clear message instead of a NameError.
try:
    from apex import amp

    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


class ModelHooks(torch.nn.Module):
    """
    Collection of overridable hook methods.

    Every ``on_*`` hook defined here is a no-op that returns ``None``;
    subclasses override the ones they care about. ``backward`` is the only
    method with a default implementation.
    """

    def on_sanity_check_start(self):
        """
        Called before starting evaluate
        :return:
        """

    def on_batch_start(self, batch):
        """
        No-op by default; override to add custom behavior.

        NOTE(review): presumably called when a batch starts - confirm
        against the caller.

        :param batch: passed in by the caller; unused by the default hook
        :return:
        """

    def on_batch_end(self):
        """
        No-op by default; override to add custom behavior.

        NOTE(review): presumably called when a batch ends - confirm
        against the caller.

        :return:
        """

    def on_epoch_start(self):
        """
        No-op by default; override to add custom behavior.

        NOTE(review): presumably called when an epoch starts - confirm
        against the caller.

        :return:
        """

    def on_epoch_end(self):
        """
        No-op by default; override to add custom behavior.

        NOTE(review): presumably called when an epoch ends - confirm
        against the caller.

        :return:
        """

    def on_pre_performance_check(self):
        """
        No-op by default; override to add custom behavior.

        NOTE(review): presumably called before a performance check runs -
        confirm against the caller.

        :return:
        """

    def on_post_performance_check(self):
        """
        No-op by default; override to add custom behavior.

        NOTE(review): presumably called after a performance check runs -
        confirm against the caller.

        :return:
        """

    def on_before_zero_grad(self, optimizer):
        """
        Called after optimizer.step() and before optimizer.zero_grad()

        Call site, for illustration::

            for optimizer in optimizers:
                optimizer.step()
                model.on_before_zero_grad(optimizer)  # <---- called here
                optimizer.zero_grad()

        :param optimizer: the optimizer that was just stepped, per the
            call site shown above
        :return:
        """

    def on_after_backward(self):
        """
        Called after loss.backward() and before optimizers do anything
        :return:
        """

    def backward(self, use_amp, loss, optimizer):
        """
        Run the backward pass on ``loss``.

        Override backward with your own implementation if you need to

        :param use_amp: Whether amp was requested or not
        :param loss: Loss is already scaled by accumulated grads
        :param optimizer: Current optimizer being used
        :return:
        :raises ModuleNotFoundError: if ``use_amp`` is truthy but apex could
            not be imported
        """
        if not use_amp:
            loss.backward()
            return

        # Without this guard a missing apex install surfaces as
        # "NameError: name 'amp' is not defined", which hides the real cause.
        if not APEX_AVAILABLE:
            raise ModuleNotFoundError(
                "use_amp was requested but apex could not be imported "
                "(`from apex import amp` failed). Install apex or call "
                "backward with use_amp=False."
            )

        # amp scales the loss for the backward pass; backward must be called
        # on the scaled loss yielded by the context manager.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()