From 89410e9090af394248f946ca29d8e3a3e892a676 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 25 Jun 2019 19:17:17 -0400 Subject: [PATCH] updated args --- docs/source/examples/example_model.py | 1 + pytorch_lightning/root_module/root_module.py | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/source/examples/example_model.py b/docs/source/examples/example_model.py index e3917248ca..a8e9db9a54 100644 --- a/docs/source/examples/example_model.py +++ b/docs/source/examples/example_model.py @@ -41,6 +41,7 @@ class ExampleModel(RootModule): # TRAINING # --------------------- def forward(self, x): + x = self.c_d1(x) x = F.tanh(x) x = self.c_d1_bn(x) diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index 890dc96c43..89a58d030e 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -25,6 +25,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks): self.gradient_clip = hparams.gradient_clip self.num = 2 self.trainer = None + self.from_lightning = True # track if gpu was requested for checkpointing self.on_gpu = False @@ -49,12 +50,13 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks): :param x: :return: """ + if self.from_lightning: + # route the forward call to the correct step type + if self.training: + return self.training_step(*args, **kwargs) + else: + return self.validation_step(*args, **kwargs) - # route the forward call to the correct step type - if self.training: - return self.training_step(*args, **kwargs) - else: - return self.validation_step(*args, **kwargs) def validation_step(self, data_batch, batch_nb): """