refactored model tests

This commit is contained in:
William Falcon 2019-07-24 13:41:28 -04:00
parent 24ceafa05c
commit b90841dc3d
2 changed files with 4 additions and 6 deletions

View File

@ -107,6 +107,8 @@ class LightningTemplateModel(LightningModule):
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc)
import pdb
pdb.set_trace()
if self.on_gpu:
val_acc = val_acc.cuda(loss_val.device.index)

View File

@ -105,14 +105,10 @@ def main():
checkpoint = ModelCheckpoint(save_dir)
trainer = Trainer(
checkpoint_callback=checkpoint,
progress_bar=True,
experiment=exp,
progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.7,
val_percent_check=0.1,
gpus=[0, 1],
distributed_backend='ddp',
distributed_backend='dp',
use_amp=True
)