updated args
This commit is contained in:
parent
73b4976500
commit
41a935185c
|
@ -164,9 +164,6 @@ class Trainer(TrainerIO):
|
|||
model.zero_grad()
|
||||
model.eval()
|
||||
|
||||
# set the model step
|
||||
model.step_split = 'val'
|
||||
|
||||
# disable gradients to save memory
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
|||
self.gradient_clip = hparams.gradient_clip
|
||||
self.num = 2
|
||||
self.trainer = None
|
||||
self.step_split = 'train'
|
||||
|
||||
# track if gpu was requested for checkpointing
|
||||
self.on_gpu = False
|
||||
|
@ -52,12 +51,10 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
|||
"""
|
||||
|
||||
# route the forward call to the correct step type
|
||||
if self.step_split == 'train':
|
||||
if self.training:
|
||||
return self.training_step(*args, **kwargs)
|
||||
elif self.step_split == 'val':
|
||||
return self.validation_step(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return self.validation_step(*args, **kwargs)
|
||||
|
||||
def validation_step(self, data_batch, batch_nb):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue