updated args

William Falcon 2019-06-25 19:04:49 -04:00
parent 4d42b1ed5f
commit 73b4976500
2 changed files with 13 additions and 2 deletions

@@ -164,6 +164,9 @@ class Trainer(TrainerIO):
         model.zero_grad()
         model.eval()
+        # set the model step
+        model.step_split = 'val'
+
         # disable gradients to save memory
         torch.set_grad_enabled(False)
@@ -183,7 +186,7 @@
         # -----------------
         # RUN VALIDATION STEP
         # -----------------
-        output = model.validation_step(data_batch, batch_i)
+        output = model(data_batch, batch_i)
         outputs.append(output)
         # batch done
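
In effect, the validation loop now flips a flag on the model and calls the module itself instead of calling validation_step directly. A minimal sketch of the call pattern this hunk implies; the run_validation helper and the restore steps at the end are illustrative, not part of the commit (the routing inside forward is sketched after the RootModule diff below):

import torch

def run_validation(model, val_dataloader):
    # mirror the diff: flag the split, switch to eval, disable grads
    model.zero_grad()
    model.eval()
    model.step_split = 'val'
    torch.set_grad_enabled(False)

    outputs = []
    for batch_i, data_batch in enumerate(val_dataloader):
        # goes through nn.Module.__call__, so forward() dispatches on step_split
        output = model(data_batch, batch_i)
        outputs.append(output)

    # restore training defaults (assumed here; not shown in the diff)
    torch.set_grad_enabled(True)
    model.step_split = 'train'
    model.train()
    return outputs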

@@ -25,6 +25,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
         self.gradient_clip = hparams.gradient_clip
         self.num = 2
         self.trainer = None
+        self.step_split = 'train'
 
         # track if gpu was requested for checkpointing
         self.on_gpu = False
@@ -49,7 +50,14 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
         :param x:
         :return:
         """
-        raise NotImplementedError
+        # route the forward call to the correct step type
+        if self.step_split == 'train':
+            return self.training_step(*args, **kwargs)
+        elif self.step_split == 'val':
+            return self.validation_step(*args, **kwargs)
+        else:
+            raise NotImplementedError
 
     def validation_step(self, data_batch, batch_nb):
         """