diff --git a/tests/debug.py b/tests/debug.py index 14d1e56b56..a7dd56aa14 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -66,7 +66,10 @@ def main(): data = model.test_dataloader for batch in data: break - out = model(batch[0]) + + x, y = batch + x = x.view(x.size(0), -1) + out = model(x) print(out)