updated args
This commit is contained in:
parent
73b4976500
commit
41a935185c
|
@ -164,9 +164,6 @@ class Trainer(TrainerIO):
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# set the model step
|
|
||||||
model.step_split = 'val'
|
|
||||||
|
|
||||||
# disable gradients to save memory
|
# disable gradients to save memory
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,6 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
||||||
self.gradient_clip = hparams.gradient_clip
|
self.gradient_clip = hparams.gradient_clip
|
||||||
self.num = 2
|
self.num = 2
|
||||||
self.trainer = None
|
self.trainer = None
|
||||||
self.step_split = 'train'
|
|
||||||
|
|
||||||
# track if gpu was requested for checkpointing
|
# track if gpu was requested for checkpointing
|
||||||
self.on_gpu = False
|
self.on_gpu = False
|
||||||
|
@ -52,12 +51,10 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# route the forward call to the correct step type
|
# route the forward call to the correct step type
|
||||||
if self.step_split == 'train':
|
if self.training:
|
||||||
return self.training_step(*args, **kwargs)
|
return self.training_step(*args, **kwargs)
|
||||||
elif self.step_split == 'val':
|
|
||||||
return self.validation_step(*args, **kwargs)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
return self.validation_step(*args, **kwargs)
|
||||||
|
|
||||||
def validation_step(self, data_batch, batch_nb):
|
def validation_step(self, data_batch, batch_nb):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue