From f183ac2a1c71dce6c685058a97d26e51b580c112 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 26 Jul 2019 22:51:33 -0400
Subject: [PATCH] added clean slurm save load test

---
 pytorch_lightning/models/trainer.py |  1 -
 tests/test_models.py                | 34 ++++++++++++++++++++---------
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 44cafb0c56..9c4b7d5fe2 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -161,7 +161,6 @@ class Trainer(TrainerIO):
         self.nb_tng_batches = None
         self.nb_test_batches = None
 
-
         # gpus come in as a string.
         # if gpus = -1 then use all available devices
         # otherwise, split the string using commas
diff --git a/tests/test_models.py b/tests/test_models.py
index 2977f46073..5cee2e5ef6 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -75,9 +75,16 @@ def test_cpu_slurm_save_load():
     assert os.path.exists(saved_filepath)
 
     # wipe-out trainer and model
+    # retrain with not much data... this simulates picking training back up after slurm
     # we want to see if the weights come back correctly
+    continue_tng_hparams = get_hparams(continue_training=True)
+    trainer_options = dict(
+        max_nb_epochs=1,
+        cluster=SlurmCluster(continue_tng_hparams),
+        experiment=exp,
+        checkpoint_callback=ModelCheckpoint(save_dir)
+    )
     trainer = Trainer(**trainer_options)
-    trainer.model = LightningTestModel(hparams)
 
     # test HPC loading
     trainer.hpc_load(save_dir, on_gpu=False)
@@ -568,16 +575,23 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
     clear_save_dir()
 
 
-def get_hparams():
+def get_hparams(continue_training=False):
     root_dir = os.path.dirname(os.path.realpath(__file__))
-    hparams = Namespace(**{'drop_prob': 0.2,
-                           'batch_size': 32,
-                           'in_features': 28*28,
-                           'learning_rate': 0.001*8,
-                           'optimizer_name': 'adam',
-                           'data_root': os.path.join(root_dir, 'mnist'),
-                           'out_features': 10,
-                           'hidden_dim': 1000})
+
+    args = {
+        'drop_prob': 0.2,
+        'batch_size': 32,
+        'in_features': 28*28,
+        'learning_rate': 0.001*8,
+        'optimizer_name': 'adam',
+        'data_root': os.path.join(root_dir, 'mnist'),
+        'out_features': 10,
+        'hidden_dim': 1000}
+
+    if continue_training:
+        args['test_tube_do_checkpoint_load'] = True
+
+    hparams = Namespace(**args)
     return hparams
 
 
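
For quick orientation, the save/load round trip test_cpu_slurm_save_load now
exercises looks roughly like the sketch below. Only the second (reload) half
appears verbatim in the hunks above; the first-run half, including the
trainer.fit call, is an assumption paraphrased from context rather than the
verbatim test body. Every other name (get_hparams, LightningTestModel,
SlurmCluster, ModelCheckpoint, hpc_load, save_dir, exp) comes from the diff
itself.

    # first run: train briefly while ModelCheckpoint writes into save_dir
    # (assumed setup -- this part of the test is not shown in the hunks)
    hparams = get_hparams()
    model = LightningTestModel(hparams)
    trainer = Trainer(max_nb_epochs=1,
                      experiment=exp,
                      checkpoint_callback=ModelCheckpoint(save_dir))
    trainer.fit(model)

    # second run: simulate picking training back up after SLURM pre-emption.
    # continue_training=True makes get_hparams set test_tube_do_checkpoint_load,
    # and the Trainer is rebuilt around a SlurmCluster before the HPC
    # checkpoint is reloaded.
    continue_tng_hparams = get_hparams(continue_training=True)
    trainer = Trainer(max_nb_epochs=1,
                      cluster=SlurmCluster(continue_tng_hparams),
                      experiment=exp,
                      checkpoint_callback=ModelCheckpoint(save_dir))
    trainer.hpc_load(save_dir, on_gpu=False)
    # the test then checks that the weights "come back correctly", i.e. the
    # reloaded weights match the ones saved before the simulated pre-emption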