diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py
index 3d242adadc..2537785175 100644
--- a/pytorch_lightning/root_module/model_saving.py
+++ b/pytorch_lightning/root_module/model_saving.py
@@ -86,7 +86,7 @@ class TrainerIO(object):
     # --------------------
     # HPC IO
     # --------------------
-    def enable_auto_hpc_walltime_manager(self):  # pragma: no cover
+    def enable_auto_hpc_walltime_manager(self):
         if self.cluster is None:
             return
 
@@ -157,6 +157,8 @@ class TrainerIO(object):
         # do the actual save
         torch.save(checkpoint_dict, filepath)
 
+        return filepath
+
     def hpc_load(self, folderpath, on_gpu):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
 
diff --git a/tests/test_models.py b/tests/test_models.py
index c20c927d6a..e5781d3211 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -3,16 +3,17 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.examples.new_project_templates.lightning_module_template import LightningTemplateModel
 from pytorch_lightning.testing_models.lm_test_module import LightningTestModel
 from argparse import Namespace
-from test_tube import Experiment
+from test_tube import Experiment, SlurmCluster
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.utils.debugging import MisconfigurationException
 from pytorch_lightning.root_module import memory
 from pytorch_lightning.models.trainer import reduce_distributed_output
+from pytorch_lightning.root_module import model_saving
 import numpy as np
 import warnings
 import torch
 import os
 import shutil
 
 SEED = 2334
 torch.manual_seed(SEED)
@@ -22,6 +23,25 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+def test_loading_meta_tags():
+    hparams = get_hparams()
+
+    save_dir = init_save_dir()
+
+    # save tags
+    exp = get_exp(False)
+    exp.tag({'some_str': 'a_str', 'an_int': 1, 'a_float': 2.0})
+    exp.argparse(hparams)
+    exp.save()
+
+    # load tags
+    tags_path = exp.get_data_path(exp.name, exp.version) + '/meta_tags.csv'
+    tags = model_saving.load_hparams_from_tags_csv(tags_path)
+
+    assert tags.batch_size == 32 and tags.hidden_dim == 1000
+
+    clear_save_dir()
+
 def test_dp_output_reduce():
 
     # test identity when we have a single gpu
@@ -43,6 +63,68 @@
     assert reduced['b']['c'] == out['b']['c']
 
 
+def test_cpu_slurm_saving_loading():
+    """
+    Verify model save/load/checkpoint on CPU
+    :return:
+    """
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    save_dir = init_save_dir()
+
+    # exp file to get meta
+    exp = get_exp(False)
+    exp.argparse(hparams)
+    exp.save()
+
+    trainer_options = dict(
+        max_nb_epochs=1,
+        cluster=SlurmCluster(),
+        experiment=exp,
+        checkpoint_callback=ModelCheckpoint(save_dir)
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    real_global_step = trainer.global_step
+
+    # training complete
+    assert result == 1, 'cpu model failed to complete'
+
+    # test saving checkpoint
+    ckpt_test = os.path.join(save_dir, 'test.ckpt')
+    trainer.save_checkpoint(ckpt_test)
+
+    # test registering a save function
+    trainer.enable_auto_hpc_walltime_manager()
+
+    # test model loading with a map_location
+    pretrained_model = load_model(exp, save_dir, True)
+
+    # test model preds
+    run_prediction(model.test_dataloader, pretrained_model)
+
+    trainer.model = pretrained_model
+    trainer.optimizers = pretrained_model.configure_optimizers()
+
+    # test HPC saving
+    saved_filepath = trainer.hpc_save(save_dir, exp)
+    assert os.path.exists(saved_filepath)
+
+    # test HPC loading
+    trainer.global_step = 20000000
+    trainer.hpc_load(save_dir, on_gpu=False)
+    assert trainer.global_step == real_global_step and trainer.global_step != 20000000
+
+    # test freeze/unfreeze
+    model.freeze()
+    model.unfreeze()
+
+    clear_save_dir()
+
+
 def test_amp_gpu_ddp_slurm_managed():
     """
     Make sure DDP + AMP work
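
Usage sketch (illustrative only, not part of the patch): with `hpc_save` now returning the path it wrote, callers can assert on the result directly. Here `trainer`, `exp`, and `save_dir` are placeholder names for a fitted Trainer, its test_tube Experiment, and a writable folder:

    # hpc_save writes hpc_ckpt_<n>.ckpt into the folder and now returns that filepath
    saved_filepath = trainer.hpc_save(save_dir, exp)
    assert os.path.exists(saved_filepath)

    # hpc_load reads back the highest-numbered hpc_ckpt_<n>.ckpt in the folder
    # (via max_ckpt_in_folder) and restores trainer state such as global_step
    trainer.hpc_load(save_dir, on_gpu=False)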