2019-03-31 01:45:16 +00:00
|
|
|
import torch
|
|
|
|
|
2019-08-04 18:08:14 +00:00
|
|
|
|
2019-10-24 11:56:56 +00:00
|
|
|
try:
|
|
|
|
from apex import amp
|
|
|
|
|
|
|
|
APEX_AVAILABLE = True
|
|
|
|
except ImportError:
|
|
|
|
APEX_AVAILABLE = False
|
|
|
|
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
class ModelHooks(torch.nn.Module):
    """Overridable callbacks invoked on the model around training steps.

    Every hook except :meth:`backward` does nothing by default; subclasses
    override only the hooks they need.
    """

    def on_sanity_check_start(self):
        """
        Called before starting the sanity-check evaluation.

        :return:
        """
        pass

    def on_batch_start(self, batch):
        """
        Hook for the start of a batch. Does nothing by default.

        :param batch: The batch about to be processed (its structure is
            defined by the caller, not here).
        :return:
        """
        pass

    def on_batch_end(self):
        """
        Hook for the end of a batch. Does nothing by default.

        :return:
        """
        pass

    def on_epoch_start(self):
        """
        Hook for the start of an epoch. Does nothing by default.

        :return:
        """
        pass

    def on_epoch_end(self):
        """
        Hook for the end of an epoch. Does nothing by default.

        :return:
        """
        pass

    def on_pre_performance_check(self):
        """
        Hook run before the performance check. Does nothing by default.

        :return:
        """
        pass

    def on_post_performance_check(self):
        """
        Hook run after the performance check. Does nothing by default.

        :return:
        """
        pass

    def on_before_zero_grad(self, optimizer):
        """
        Called after optimizer.step() and before optimizer.zero_grad().

        Called in the training loop like this::

            for optimizer in optimizers:
                optimizer.step()
                model.on_before_zero_grad(optimizer)  # <---- called here
                optimizer.zero_grad()

        :param optimizer: The optimizer whose step was just taken.
        :return:
        """
        pass

    def on_after_backward(self):
        """
        Called after loss.backward() and before optimizers do anything.

        :return:
        """
        pass

    def backward(self, use_amp, loss, optimizer):
        """
        Override backward with your own implementation if you need to.

        :param use_amp: Whether amp was requested or not. When true, the
            backward pass runs on the loss returned by
            ``apex.amp.scale_loss``, so apex must be importable.
        :param loss: Loss is already scaled by accumulated grads.
        :param optimizer: Current optimizer being used; only consulted when
            ``use_amp`` is true.
        :return:
        :raises ImportError: If ``use_amp`` is true but apex could not be
            imported when this module was loaded.
        """
        if use_amp:
            # ``amp`` is only bound at module level when the apex import at
            # the top of this file succeeded; without this guard the call
            # fails with an opaque ``NameError`` instead.
            if not APEX_AVAILABLE:
                raise ImportError(
                    "use_amp=True requires NVIDIA apex (`from apex import amp`), "
                    "which could not be imported."
                )
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
|