diff --git a/tests/debug.py b/tests/debug.py index 1375930d91..49483e0bbb 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -62,6 +62,14 @@ def main(): # correct result and ok accuracy assert result == 1, 'amp + ddp model failed to complete' + # test prediction + data = model.test_dataloader + for batch in data: + break + out = model(data[0]) + print(out) + + clear_tt_dir()