diff --git a/tests/debug.py b/tests/debug.py index ed8d26cfb4..ca701d713d 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -3,6 +3,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.examples.new_project_templates.lightning_module_template import LightningTemplateModel from argparse import Namespace from test_tube import Experiment +from pytorch_lightning.callbacks import ModelCheckpoint import numpy as np import warnings import torch @@ -27,28 +28,48 @@ def get_model(): return model -def get_exp(): +def get_exp(debug=True): # set up exp object without actually saving logs root_dir = os.path.dirname(os.path.realpath(__file__)) - exp = Experiment(debug=True, save_dir=root_dir, name='tests_tt_dir') + exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir') return exp -def clear_tt_dir(): +def init_save_dir(): root_dir = os.path.dirname(os.path.realpath(__file__)) - tt_dir = os.path.join(root_dir, 'tests_tt_dir') - if os.path.exists(tt_dir): - shutil.rmtree(tt_dir) + save_dir = os.path.join(root_dir, 'save_dir') + + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + + os.makedirs(save_dir, exist_ok=True) + + return save_dir + + +def clear_save_dir(): + root_dir = os.path.dirname(os.path.realpath(__file__)) + save_dir = os.path.join(root_dir, 'save_dir') + if os.path.exists(save_dir): + shutil.rmtree(save_dir) def main(): - clear_tt_dir() + save_dir = init_save_dir() model = get_model() + # exp file to get meta + exp = get_exp(False) + exp.save() + + # exp file to get weights + checkpoint = ModelCheckpoint(save_dir) + trainer = Trainer( + checkpoint_callback=checkpoint, progress_bar=True, - experiment=get_exp(), + experiment=exp, max_nb_epochs=1, train_percent_check=0.1, val_percent_check=0.1, @@ -62,22 +83,12 @@ def main(): # correct result and ok accuracy assert result == 1, 'amp + ddp model failed to complete' - # test prediction - data = model.val_dataloader - for batch in data: - break + # load trained model + pdb.set_trace() + tags_path = exp.get_data_path(exp.name, exp.version) + LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=) - x, y = batch - x = x.view(x.size(0), -1) - out = model(x) - - labels_hat = torch.argmax(out, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) - print(val_acc) - - - clear_tt_dir() + clear_save_dir() if __name__ == '__main__':