From 85eaa28872ba0d445d5a231070be173d8d3dd07d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 11:51:38 -0400 Subject: [PATCH] added test for model loading and predicting --- tests/debug.py | 66 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/tests/debug.py b/tests/debug.py index 21761baac7..8fbd4e3c40 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -54,6 +54,41 @@ def clear_save_dir(): shutil.rmtree(save_dir) +def load_model(exp, save_dir): + + # load trained model + tags_path = exp.get_data_path(exp.name, exp.version) + tags_path = os.path.join(tags_path, 'meta_tags.csv') + + checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x] + weights_dir = os.path.join(save_dir, checkpoints[0]) + + trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=True) + + assert trained_model is not None, 'loading model failed' + + return trained_model + + +def run_prediction(dataloader, trained_model): + # run prediction on 1 batch + for batch in dataloader: + break + + x, y = batch + x = x.view(x.size(0), -1) + + y_hat = trained_model(x) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + val_acc = val_acc.item() + + assert val_acc > 0.70, 'this model is expected to get > 0.7 in test set' + + def main(): save_dir = init_save_dir() @@ -72,7 +107,7 @@ def main(): progress_bar=True, experiment=exp, max_nb_epochs=1, - train_percent_check=0.1, + train_percent_check=0.2, val_percent_check=0.1, gpus=[0, 1], distributed_backend='ddp', @@ -84,32 +119,11 @@ def main(): # correct result and ok accuracy assert result == 1, 'amp + ddp model failed to complete' - # load trained model - pdb.set_trace() - tags_path = exp.get_data_path(exp.name, exp.version) - tags_path = os.path.join(tags_path, 'meta_tags.csv') + # test model loading + pretrained_model = load_model(exp, save_dir) - pdb.set_trace() - checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x] - weights_dir = os.path.join(save_dir, checkpoints[0]) - - trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=True) - - # run prediction - for batch in model.test_dataloader: - break - - x, y = batch - x = x.view(x.size(0), -1) - - y_hat = trained_model(x) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) - - print(val_acc) + # test model preds + run_prediction(model.test_dataloader, pretrained_model) clear_save_dir()