added test for model loading and predicting
This commit is contained in:
parent
aa90040387
commit
d3651ba15c
|
@ -86,7 +86,7 @@ def run_prediction(dataloader, trained_model):
|
||||||
val_acc = torch.tensor(val_acc)
|
val_acc = torch.tensor(val_acc)
|
||||||
val_acc = val_acc.item()
|
val_acc = val_acc.item()
|
||||||
|
|
||||||
assert val_acc > 0.60, 'this model is expected to get > 0.7 in test set'
|
assert val_acc > 0.60, f'this model is expected to get > 0.7 in test set (it got {val_acc})'
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
Loading…
Reference in New Issue