if return -1 from a hook that loop stopps
This commit is contained in:
parent
e89da15f18
commit
4b0b7e5ea3
|
@ -327,7 +327,7 @@ class Trainer(TrainerIO):
|
||||||
|
|
||||||
# hook
|
# hook
|
||||||
if self.__is_function_implemented('on_batch_start'):
|
if self.__is_function_implemented('on_batch_start'):
|
||||||
response = self.model.on_batch_start()
|
response = self.model.on_batch_start(data_batch)
|
||||||
if response == -1:
|
if response == -1:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
class ModelHooks(torch.nn.Module):
|
class ModelHooks(torch.nn.Module):
|
||||||
def on_batch_start(self):
|
def on_batch_start(self, data_batch):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_batch_end(self):
|
def on_batch_end(self):
|
||||||
|
|
Loading…
Reference in New Issue