From 41a935185c792b24f83b2efe77ceed6aa7e510b0 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 25 Jun 2019 19:06:19 -0400 Subject: [PATCH] updated args --- pytorch_lightning/models/trainer.py | 3 --- pytorch_lightning/root_module/root_module.py | 7 ++----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 110bea5acd..8e89863e39 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -164,9 +164,6 @@ class Trainer(TrainerIO): model.zero_grad() model.eval() - # set the model step - model.step_split = 'val' - # disable gradients to save memory torch.set_grad_enabled(False) diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index b9a6cb4d8e..890dc96c43 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -25,7 +25,6 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks): self.gradient_clip = hparams.gradient_clip self.num = 2 self.trainer = None - self.step_split = 'train' # track if gpu was requested for checkpointing self.on_gpu = False @@ -52,12 +51,10 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks): """ # route the forward call to the correct step type - if self.step_split == 'train': + if self.training: return self.training_step(*args, **kwargs) - elif self.step_split == 'val': - return self.validation_step(*args, **kwargs) else: - raise NotImplementedError + return self.validation_step(*args, **kwargs) def validation_step(self, data_batch, batch_nb): """