added clean slurm save load test

This commit is contained in:
William Falcon 2019-07-26 22:51:33 -04:00
parent 61c82611eb
commit f183ac2a1c
2 changed files with 24 additions and 11 deletions

View File

@ -161,7 +161,6 @@ class Trainer(TrainerIO):
self.nb_tng_batches = None self.nb_tng_batches = None
self.nb_test_batches = None self.nb_test_batches = None
# gpus come in as a string. # gpus come in as a string.
# if gpus = -1 then use all available devices # if gpus = -1 then use all available devices
# otherwise, split the string using commas # otherwise, split the string using commas

View File

@ -75,9 +75,16 @@ def test_cpu_slurm_save_load():
assert os.path.exists(saved_filepath) assert os.path.exists(saved_filepath)
# wipe-out trainer and model # wipe-out trainer and model
# retrain with not much data... this simulates picking training back up after slurm
# we want to see if the weights come back correctly # we want to see if the weights come back correctly
continue_tng_hparams = get_hparams(continue_training=True)
trainer_options = dict(
max_nb_epochs=1,
cluster=SlurmCluster(continue_tng_hparams),
experiment=exp,
checkpoint_callback=ModelCheckpoint(save_dir)
)
trainer = Trainer(**trainer_options) trainer = Trainer(**trainer_options)
trainer.model = LightningTestModel(hparams)
# test HPC loading # test HPC loading
trainer.hpc_load(save_dir, on_gpu=False) trainer.hpc_load(save_dir, on_gpu=False)
@ -568,16 +575,23 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
clear_save_dir() clear_save_dir()
def get_hparams(): def get_hparams(continue_training=False):
root_dir = os.path.dirname(os.path.realpath(__file__)) root_dir = os.path.dirname(os.path.realpath(__file__))
hparams = Namespace(**{'drop_prob': 0.2,
'batch_size': 32, args = {
'in_features': 28*28, 'drop_prob': 0.2,
'learning_rate': 0.001*8, 'batch_size': 32,
'optimizer_name': 'adam', 'in_features': 28*28,
'data_root': os.path.join(root_dir, 'mnist'), 'learning_rate': 0.001*8,
'out_features': 10, 'optimizer_name': 'adam',
'hidden_dim': 1000}) 'data_root': os.path.join(root_dir, 'mnist'),
'out_features': 10,
'hidden_dim': 1000}
if continue_training:
args['test_tube_do_checkpoint_load'] = True
hparams = Namespace(**args)
return hparams return hparams