updated args
This commit is contained in:
parent
c4da914747
commit
89410e9090
|
@ -41,6 +41,7 @@ class ExampleModel(RootModule):
|
|||
# TRAINING
|
||||
# ---------------------
|
||||
def forward(self, x):
|
||||
|
||||
x = self.c_d1(x)
|
||||
x = F.tanh(x)
|
||||
x = self.c_d1_bn(x)
|
||||
|
|
|
@ -25,6 +25,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
|||
self.gradient_clip = hparams.gradient_clip
|
||||
self.num = 2
|
||||
self.trainer = None
|
||||
self.from_lightning = True
|
||||
|
||||
# track if gpu was requested for checkpointing
|
||||
self.on_gpu = False
|
||||
|
@ -49,12 +50,13 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
|||
:param x:
|
||||
:return:
|
||||
"""
|
||||
if self.from_lightning:
|
||||
# route the forward call to the correct step type
|
||||
if self.training:
|
||||
return self.training_step(*args, **kwargs)
|
||||
else:
|
||||
return self.validation_step(*args, **kwargs)
|
||||
|
||||
# route the forward call to the correct step type
|
||||
if self.training:
|
||||
return self.training_step(*args, **kwargs)
|
||||
else:
|
||||
return self.validation_step(*args, **kwargs)
|
||||
|
||||
def validation_step(self, data_batch, batch_nb):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue