added clean slurm save load test

This commit is contained in:
William Falcon 2019-07-26 23:04:41 -04:00
parent 64586f271d
commit 587c195298
1 changed files with 3 additions and 3 deletions

View File

@ -78,13 +78,12 @@ def test_cpu_slurm_save_load():
# wipe-out trainer and model
# retrain with not much data... this simulates picking training back up after slurm
# we want to see if the weights come back correctly
continue_tng_hparams = get_hparams(continue_training=True)
continue_tng_hparams = get_hparams(continue_training=True, hpc_exp_number=cluster_a.hpc_exp_number)
trainer_options = dict(
max_nb_epochs=1,
cluster=SlurmCluster(continue_tng_hparams),
experiment=exp,
checkpoint_callback=ModelCheckpoint(save_dir),
hpc_exp_number=cluster_a.hpc_exp_number
)
trainer = Trainer(**trainer_options)
model = LightningTestModel(hparams)
@ -584,7 +583,7 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
clear_save_dir()
def get_hparams(continue_training=False):
def get_hparams(continue_training=False, hpc_exp_number=0):
root_dir = os.path.dirname(os.path.realpath(__file__))
args = {
@ -599,6 +598,7 @@ def get_hparams(continue_training=False):
if continue_training:
args['test_tube_do_checkpoint_load'] = True
args['hpc_exp_number'] = hpc_exp_number
hparams = Namespace(**args)
return hparams