diff --git a/tests/debug.py b/tests/debug.py index a7dd56aa14..6b27fd33c7 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -70,7 +70,11 @@ def main(): x, y = batch x = x.view(x.size(0), -1) out = model(x) - print(out) + + labels_hat = torch.argmax(out, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + print(val_acc) clear_tt_dir()