updated args
This commit is contained in:
parent
4d42b1ed5f
commit
73b4976500
|
@ -164,6 +164,9 @@ class Trainer(TrainerIO):
|
|||
model.zero_grad()
|
||||
model.eval()
|
||||
|
||||
# set the model step
|
||||
model.step_split = 'val'
|
||||
|
||||
# disable gradients to save memory
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
@ -183,7 +186,7 @@ class Trainer(TrainerIO):
|
|||
# -----------------
|
||||
# RUN VALIDATION STEP
|
||||
# -----------------
|
||||
output = model.validation_step(data_batch, batch_i)
|
||||
output = model(data_batch, batch_i)
|
||||
outputs.append(output)
|
||||
|
||||
# batch done
|
||||
|
|
|
@ -25,6 +25,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
|||
self.gradient_clip = hparams.gradient_clip
|
||||
self.num = 2
|
||||
self.trainer = None
|
||||
self.step_split = 'train'
|
||||
|
||||
# track if gpu was requested for checkpointing
|
||||
self.on_gpu = False
|
||||
|
@ -49,7 +50,14 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
|||
:param x:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# route the forward call to the correct step type
|
||||
if self.step_split == 'train':
|
||||
return self.training_step(*args, **kwargs)
|
||||
elif self.step_split == 'val':
|
||||
return self.validation_step(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_step(self, data_batch, batch_nb):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue