if return -1 from a hook that loop stopps

This commit is contained in:
William Falcon 2019-04-21 13:40:32 -04:00
parent e89da15f18
commit 4b0b7e5ea3
2 changed files with 2 additions and 2 deletions

View File

@ -327,7 +327,7 @@ class Trainer(TrainerIO):
# hook
if self.__is_function_implemented('on_batch_start'):
response = self.model.on_batch_start()
response = self.model.on_batch_start(data_batch)
if response == -1:
return

View File

@ -1,7 +1,7 @@
import torch
class ModelHooks(torch.nn.Module):
def on_batch_start(self):
def on_batch_start(self, data_batch):
pass
def on_batch_end(self):