added test for model loading and predicting
This commit is contained in:
parent
d3651ba15c
commit
8781d8aeab
|
@ -86,6 +86,8 @@ def run_prediction(dataloader, trained_model):
|
|||
val_acc = torch.tensor(val_acc)
|
||||
val_acc = val_acc.item()
|
||||
|
||||
print(val_acc)
|
||||
|
||||
assert val_acc > 0.60, f'this model is expected to get > 0.7 in test set (it got {val_acc})'
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue