added clean slurm save load test

This commit is contained in:
William Falcon 2019-07-26 23:02:18 -04:00
parent 53b781709e
commit 64586f271d
1 changed files with 4 additions and 2 deletions

View File

@ -40,9 +40,10 @@ def test_cpu_slurm_save_load():
exp.argparse(hparams)
exp.save()
cluster_a = SlurmCluster()
trainer_options = dict(
max_nb_epochs=1,
cluster=SlurmCluster(),
cluster=cluster_a,
experiment=exp,
checkpoint_callback=ModelCheckpoint(save_dir)
)
@ -82,7 +83,8 @@ def test_cpu_slurm_save_load():
max_nb_epochs=1,
cluster=SlurmCluster(continue_tng_hparams),
experiment=exp,
checkpoint_callback=ModelCheckpoint(save_dir)
checkpoint_callback=ModelCheckpoint(save_dir),
hpc_exp_number=cluster_a.hpc_exp_number
)
trainer = Trainer(**trainer_options)
model = LightningTestModel(hparams)