added test for model loading and predicting

This commit is contained in:
William Falcon 2019-07-24 11:55:22 -04:00
parent aa90040387
commit d3651ba15c
1 changed files with 1 additions and 1 deletions

View File

@ -86,7 +86,7 @@ def run_prediction(dataloader, trained_model):
val_acc = torch.tensor(val_acc) val_acc = torch.tensor(val_acc)
val_acc = val_acc.item() val_acc = val_acc.item()
assert val_acc > 0.60, 'this model is expected to get > 0.7 in test set' assert val_acc > 0.60, f'this model is expected to get > 0.7 in test set (it got {val_acc})'
def main(): def main():