updated args

This commit is contained in:
William Falcon 2019-06-25 19:17:17 -04:00
parent c4da914747
commit 89410e9090
2 changed files with 8 additions and 5 deletions

View File

@ -41,6 +41,7 @@ class ExampleModel(RootModule):
# TRAINING
# ---------------------
def forward(self, x):
x = self.c_d1(x)
x = F.tanh(x)
x = self.c_d1_bn(x)

View File

@ -25,6 +25,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
self.gradient_clip = hparams.gradient_clip
self.num = 2
self.trainer = None
self.from_lightning = True
# track if gpu was requested for checkpointing
self.on_gpu = False
@ -49,12 +50,13 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
:param x:
:return:
"""
if self.from_lightning:
# route the forward call to the correct step type
if self.training:
return self.training_step(*args, **kwargs)
else:
return self.validation_step(*args, **kwargs)
# route the forward call to the correct step type
if self.training:
return self.training_step(*args, **kwargs)
else:
return self.validation_step(*args, **kwargs)
def validation_step(self, data_batch, batch_nb):
"""