diff --git a/tests/debug.py b/tests/debug.py index 49483e0bbb..14d1e56b56 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -66,7 +66,7 @@ def main(): data = model.test_dataloader for batch in data: break - out = model(data[0]) + out = model(batch[0]) print(out)