diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 87767d7d59..2a23cf62a1 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -327,7 +327,7 @@ class Trainer(TrainerIO): # hook if self.__is_function_implemented('on_batch_start'): - response = self.model.on_batch_start() + response = self.model.on_batch_start(data_batch) if response == -1: return diff --git a/pytorch_lightning/root_module/hooks.py b/pytorch_lightning/root_module/hooks.py index 0d903f6627..99155ab9a1 100644 --- a/pytorch_lightning/root_module/hooks.py +++ b/pytorch_lightning/root_module/hooks.py @@ -1,7 +1,7 @@ import torch class ModelHooks(torch.nn.Module): - def on_batch_start(self): + def on_batch_start(self, data_batch): pass def on_batch_end(self):