refactored model tests
This commit is contained in:
parent
24ceafa05c
commit
b90841dc3d
|
@ -107,6 +107,8 @@ class LightningTemplateModel(LightningModule):
|
|||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc)
|
||||
|
||||
import pdb
|
||||
pdb.set_trace()
|
||||
if self.on_gpu:
|
||||
val_acc = val_acc.cuda(loss_val.device.index)
|
||||
|
||||
|
|
|
@ -105,14 +105,10 @@ def main():
|
|||
checkpoint = ModelCheckpoint(save_dir)
|
||||
|
||||
trainer = Trainer(
|
||||
checkpoint_callback=checkpoint,
|
||||
progress_bar=True,
|
||||
experiment=exp,
|
||||
progress_bar=False,
|
||||
max_nb_epochs=1,
|
||||
train_percent_check=0.7,
|
||||
val_percent_check=0.1,
|
||||
gpus=[0, 1],
|
||||
distributed_backend='ddp',
|
||||
distributed_backend='dp',
|
||||
use_amp=True
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue