54 lines
1.0 KiB
Python
54 lines
1.0 KiB
Python
import torch
|
|
|
|
|
|
class ModelHooks(torch.nn.Module):
    """No-op hook methods for a ``torch.nn.Module``-based model.

    Every hook defined here does nothing and returns ``None``. Presumably a
    model subclasses this and overrides only the hooks it needs, while a
    training loop elsewhere invokes them — TODO confirm the exact call sites
    against the caller.
    """

    def on_sanity_check_start(self) -> None:
        """Called before starting the sanity-check validation run.

        :return: None
        """

    def on_batch_start(self, data_batch) -> None:
        """Hook that receives a batch of data; does nothing by default.

        Presumably called before each batch is processed — confirm against
        the caller.

        :param data_batch: the batch supplied by the caller; its type and
            structure are not fixed by this class.
        :return: None
        """

    def on_batch_end(self) -> None:
        """Batch-end hook; does nothing by default.

        Presumably called after each batch is processed — confirm against
        the caller.

        :return: None
        """

    def on_epoch_start(self) -> None:
        """Epoch-start hook; does nothing by default.

        Presumably called at the start of each epoch — confirm against the
        caller.

        :return: None
        """

    def on_epoch_end(self) -> None:
        """Epoch-end hook; does nothing by default.

        Presumably called at the end of each epoch — confirm against the
        caller.

        :return: None
        """

    def on_pre_performance_check(self) -> None:
        """Hook run before a "performance check"; does nothing by default.

        NOTE(review): what a performance check is (e.g. a validation pass)
        is defined by the caller, not here — confirm.

        :return: None
        """

    def on_post_performance_check(self) -> None:
        """Hook run after a "performance check"; does nothing by default.

        NOTE(review): what a performance check is (e.g. a validation pass)
        is defined by the caller, not here — confirm.

        :return: None
        """

    def on_tng_metrics(self, metrics) -> None:
        """Hook that receives metrics; does nothing by default.

        Presumably called with training ("tng") metrics — the schema of
        ``metrics`` is not fixed by this class; confirm against the caller.

        :param metrics: metrics object supplied by the caller.
        :return: None
        """

    def on_before_zero_grad(self, optimizer) -> None:
        """Called after ``optimizer.step()`` and before ``optimizer.zero_grad()``.

        Call order::

            for optimizer in optimizers:
                optimizer.step()
                model.on_before_zero_grad(optimizer)  # <---- called here
                optimizer.zero_grad()

        :param optimizer: the optimizer that has just stepped and whose
            gradients are about to be zeroed.
        :return: None
        """

    def on_after_backward(self) -> None:
        """Called after ``loss.backward()`` and before optimizers do anything.

        :return: None
        """
|