diff --git a/tests/debug.py b/tests/debug.py index 7f97d86dfe..c0d845c50e 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -96,7 +96,7 @@ def main(): x, y = batch x = x.view(x.size(0), -1) - y_hat = model(x) + y_hat = trained_model(x) # acc labels_hat = torch.argmax(y_hat, dim=1)