added clean slurm save load test
parent 64586f271d
commit 587c195298
@@ -78,13 +78,12 @@ def test_cpu_slurm_save_load():
     # wipe-out trainer and model
     # retrain with not much data... this simulates picking training back up after slurm
     # we want to see if the weights come back correctly
-    continue_tng_hparams = get_hparams(continue_training=True)
+    continue_tng_hparams = get_hparams(continue_training=True, hpc_exp_number=cluster_a.hpc_exp_number)
     trainer_options = dict(
         max_nb_epochs=1,
         cluster=SlurmCluster(continue_tng_hparams),
         experiment=exp,
         checkpoint_callback=ModelCheckpoint(save_dir),
-        hpc_exp_number=cluster_a.hpc_exp_number
     )
     trainer = Trainer(**trainer_options)
     model = LightningTestModel(hparams)
@@ -584,7 +583,7 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
     clear_save_dir()


-def get_hparams(continue_training=False):
+def get_hparams(continue_training=False, hpc_exp_number=0):
     root_dir = os.path.dirname(os.path.realpath(__file__))

     args = {
@@ -599,6 +598,7 @@ def get_hparams(continue_training=False):

     if continue_training:
         args['test_tube_do_checkpoint_load'] = True
+        args['hpc_exp_number'] = hpc_exp_number

     hparams = Namespace(**args)
     return hparams
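Read together, the hunks make one change: the HPC experiment number now travels through the hparams namespace instead of being passed directly to the Trainer, so a resumed SLURM run can locate the matching checkpoint. A minimal sketch of that pattern, assuming a trimmed-down get_hparams in which 'drop_prob' and 'batch_size' are placeholders for the real model options built by the full test:

from argparse import Namespace

def get_hparams(continue_training=False, hpc_exp_number=0):
    # Trimmed-down sketch of the test helper; the real test builds a much
    # larger args dict of model and data options.
    args = {
        'drop_prob': 0.2,
        'batch_size': 32,
    }
    if continue_training:
        # Resuming after a SLURM requeue: reload the checkpoint and record
        # which HPC experiment number it belongs to.
        args['test_tube_do_checkpoint_load'] = True
        args['hpc_exp_number'] = hpc_exp_number
    return Namespace(**args)

# Usage in the test (sketch): the number comes from the first cluster object,
# so the "resumed" run reloads the matching HPC checkpoint.
# continue_tng_hparams = get_hparams(continue_training=True,
#                                    hpc_exp_number=cluster_a.hpc_exp_number)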