import torch class ModelHooks(torch.nn.Module): def on_batch_start(self, data_batch): pass def on_batch_end(self): pass def on_epoch_start(self): pass def on_epoch_end(self): pass def on_pre_performance_check(self): pass def on_post_performance_check(self): pass def on_tng_metrics(self, metrics): pass