updated args

This commit is contained in:
William Falcon 2019-06-25 19:06:19 -04:00
parent 73b4976500
commit 41a935185c
2 changed files with 2 additions and 8 deletions

View File

@ -164,9 +164,6 @@ class Trainer(TrainerIO):
model.zero_grad() model.zero_grad()
model.eval() model.eval()
# set the model step
model.step_split = 'val'
# disable gradients to save memory # disable gradients to save memory
torch.set_grad_enabled(False) torch.set_grad_enabled(False)

View File

@ -25,7 +25,6 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
self.gradient_clip = hparams.gradient_clip self.gradient_clip = hparams.gradient_clip
self.num = 2 self.num = 2
self.trainer = None self.trainer = None
self.step_split = 'train'
# track if gpu was requested for checkpointing # track if gpu was requested for checkpointing
self.on_gpu = False self.on_gpu = False
@ -52,12 +51,10 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
""" """
# route the forward call to the correct step type # route the forward call to the correct step type
if self.step_split == 'train': if self.training:
return self.training_step(*args, **kwargs) return self.training_step(*args, **kwargs)
elif self.step_split == 'val':
return self.validation_step(*args, **kwargs)
else: else:
raise NotImplementedError return self.validation_step(*args, **kwargs)
def validation_step(self, data_batch, batch_nb): def validation_step(self, data_batch, batch_nb):
""" """