From 587c195298171c8e27e404da8998fda4415b53b8 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 26 Jul 2019 23:04:41 -0400
Subject: [PATCH] added clean slurm save load test

---
 tests/test_models.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 6540204abf..e8fa339b54 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -78,13 +78,12 @@ def test_cpu_slurm_save_load():
     # wipe-out trainer and model
     # retrain with not much data... this simulates picking training back up after slurm
     # we want to see if the weights come back correctly
-    continue_tng_hparams = get_hparams(continue_training=True)
+    continue_tng_hparams = get_hparams(continue_training=True, hpc_exp_number=cluster_a.hpc_exp_number)
     trainer_options = dict(
         max_nb_epochs=1,
         cluster=SlurmCluster(continue_tng_hparams),
         experiment=exp,
         checkpoint_callback=ModelCheckpoint(save_dir),
-        hpc_exp_number=cluster_a.hpc_exp_number
     )
     trainer = Trainer(**trainer_options)
     model = LightningTestModel(hparams)
@@ -584,7 +583,7 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
     clear_save_dir()
 
 
-def get_hparams(continue_training=False):
+def get_hparams(continue_training=False, hpc_exp_number=0):
     root_dir = os.path.dirname(os.path.realpath(__file__))
 
     args = {
@@ -599,6 +598,7 @@
 
     if continue_training:
         args['test_tube_do_checkpoint_load'] = True
+        args['hpc_exp_number'] = hpc_exp_number
 
     hparams = Namespace(**args)
     return hparams
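
Note on the pattern above: the change moves hpc_exp_number out of the Trainer options and threads it through the hparams Namespace via get_hparams, so the resume path carries the HPC experiment number alongside the other hyperparameters. A minimal sketch of that pattern follows, using only the stdlib Namespace; the placeholder args dict is illustrative, not the real test helper, which also fills in paths and data settings.

    from argparse import Namespace

    def get_hparams(continue_training=False, hpc_exp_number=0):
        # placeholder fields; the real helper also sets root_dir, batch size, etc.
        args = {'max_nb_epochs': 1}
        if continue_training:
            # resuming after a SLURM requeue: flag the checkpoint load and record
            # which HPC experiment number the weights should be restored from
            args['test_tube_do_checkpoint_load'] = True
            args['hpc_exp_number'] = hpc_exp_number
        return Namespace(**args)

    # the first training run assigns an experiment number (cluster_a.hpc_exp_number
    # in the test); the resume run passes it back in through hparams
    hparams = get_hparams(continue_training=True, hpc_exp_number=1)
    assert hparams.hpc_exp_number == 1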