added safeguards for callbacks in loading saving

This commit is contained in:
William Falcon 2019-07-24 11:34:56 -04:00
parent 98c112598e
commit 55a33edd0a
1 changed files with 17 additions and 1 deletions

View File

@ -87,7 +87,23 @@ def main():
pdb.set_trace()
tags_path = exp.get_data_path(exp.name, exp.version)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=tags_path)
trained_model = LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=tags_path)
# run prediction
for batch in model.test_dataloader:
break
x, y = batch
x = x.view(x.size(0), -1)
y_hat = model(x)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc)
print(val_acc)
clear_save_dir()