diff --git a/tests/debug.py b/tests/debug.py index 6910f3da2a..9f2959afea 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -16,13 +16,13 @@ def get_model(): # set up model with these hyperparams root_dir = os.path.dirname(os.path.realpath(__file__)) hparams = TTNamespace(**{'drop_prob': 0.2, - 'batch_size': 32, - 'in_features': 28*28, - 'learning_rate': 0.001*8, - 'optimizer_name': 'adam', - 'data_root': os.path.join(root_dir, 'mnist'), - 'out_features': 10, - 'hidden_dim': 1000}) + 'batch_size': 32, + 'in_features': 28*28, + 'learning_rate': 0.001*8, + 'optimizer_name': 'adam', + 'data_root': os.path.join(root_dir, 'mnist'), + 'out_features': 10, + 'hidden_dim': 1000}) model = LightningTemplateModel(hparams) return model, hparams