From 4b0b7e5ea3920c1d76ddd409550632c28533e120 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 21 Apr 2019 13:40:32 -0400 Subject: [PATCH] if return -1 from a hook that loop stopps --- pytorch_lightning/models/trainer.py | 2 +- pytorch_lightning/root_module/hooks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 87767d7d59..2a23cf62a1 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -327,7 +327,7 @@ class Trainer(TrainerIO): # hook if self.__is_function_implemented('on_batch_start'): - response = self.model.on_batch_start() + response = self.model.on_batch_start(data_batch) if response == -1: return diff --git a/pytorch_lightning/root_module/hooks.py b/pytorch_lightning/root_module/hooks.py index 0d903f6627..99155ab9a1 100644 --- a/pytorch_lightning/root_module/hooks.py +++ b/pytorch_lightning/root_module/hooks.py @@ -1,7 +1,7 @@ import torch class ModelHooks(torch.nn.Module): - def on_batch_start(self): + def on_batch_start(self, data_batch): pass def on_batch_end(self):