From 55a33edd0ac45ffef922a83c1b4e820d33a86717 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 11:34:56 -0400 Subject: [PATCH] added safeguards for callbacks in loading saving --- tests/debug.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/debug.py b/tests/debug.py index 6caebd6d5d..a9700356b6 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -87,7 +87,23 @@ def main(): pdb.set_trace() tags_path = exp.get_data_path(exp.name, exp.version) tags_path = os.path.join(tags_path, 'meta_tags.csv') - LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=tags_path) + trained_model = LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=tags_path) + + # run prediction + for batch in model.test_dataloader: + break + + x, y = batch + x = x.view(x.size(0), -1) + + y_hat = model(x) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + + print(val_acc) clear_save_dir()