parent
f9c9e39ab8
commit
2ec8d61e94
|
@ -100,7 +100,7 @@ To also add a validation loop add the following functions
|
|||
def validation_epoch_end(self, outputs):
|
||||
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
|
||||
tensorboard_logs = {'val_loss': avg_loss}
|
||||
return {'val_loss': avg_loss, 'log': tensorboard_logs
|
||||
return {'val_loss': avg_loss, 'log': tensorboard_logs}
|
||||
|
||||
def val_dataloader(self):
|
||||
# TODO: do a real train/val split
|
||||
|
|
Loading…
Reference in New Issue