diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index b208efc080..110bea5acd 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -164,6 +164,9 @@ class Trainer(TrainerIO):
         model.zero_grad()
         model.eval()
 
+        # set the model step
+        model.step_split = 'val'
+
         # disable gradients to save memory
         torch.set_grad_enabled(False)
 
@@ -183,7 +186,7 @@ class Trainer(TrainerIO):
             # -----------------
             # RUN VALIDATION STEP
             # -----------------
-            output = model.validation_step(data_batch, batch_i)
+            output = model(data_batch, batch_i)
             outputs.append(output)
 
             # batch done
diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index 467aeeadf5..b9a6cb4d8e 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -25,6 +25,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
         self.gradient_clip = hparams.gradient_clip
         self.num = 2
         self.trainer = None
+        self.step_split = 'train'
 
         # track if gpu was requested for checkpointing
         self.on_gpu = False
@@ -49,7 +50,14 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
         :param x:
         :return:
         """
-        raise NotImplementedError
+
+        # route the forward call to the correct step type
+        if self.step_split == 'train':
+            return self.training_step(*args, **kwargs)
+        elif self.step_split == 'val':
+            return self.validation_step(*args, **kwargs)
+        else:
+            raise NotImplementedError
 
     def validation_step(self, data_batch, batch_nb):
         """