added clean slurm save load test

parent 61c82611eb
commit f183ac2a1c
@@ -161,7 +161,6 @@ class Trainer(TrainerIO):
         self.nb_tng_batches = None
         self.nb_test_batches = None
 
-
         # gpus come in as a string.
         # if gpus = -1 then use all available devices
         # otherwise, split the string using commas
@@ -75,9 +75,16 @@ def test_cpu_slurm_save_load():
     assert os.path.exists(saved_filepath)
 
     # wipe-out trainer and model
+    # retrain with not much data... this simulates picking training back up after slurm
     # we want to see if the weights come back correctly
+    continue_tng_hparams = get_hparams(continue_training=True)
+    trainer_options = dict(
+        max_nb_epochs=1,
+        cluster=SlurmCluster(continue_tng_hparams),
+        experiment=exp,
+        checkpoint_callback=ModelCheckpoint(save_dir)
+    )
     trainer = Trainer(**trainer_options)
-    trainer.model = LightningTestModel(hparams)
 
     # test HPC loading
     trainer.hpc_load(save_dir, on_gpu=False)
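Taken together, the hunk above makes the test simulate a SLURM requeue: train briefly, checkpoint, tear the trainer down, then rebuild it from get_hparams(continue_training=True) and call hpc_load to restore the weights. The following is a minimal, self-contained sketch of the weight round-trip this test asserts, written in plain PyTorch rather than with the library's hpc_load; the model class and paths are hypothetical.

# Minimal sketch only, not the library's hpc_load implementation.
# CheckpointNet and ckpt_path are hypothetical names; plain PyTorch
# is the only dependency.
import os
import tempfile

import torch
import torch.nn as nn


class CheckpointNet(nn.Module):
    # hypothetical stand-in for LightningTestModel
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.fc(x)


save_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(save_dir, 'hpc_ckpt.pt')

# training runs, then SLURM preempts the job: weights go to save_dir
model = CheckpointNet()
torch.save(model.state_dict(), ckpt_path)

# the requeued job starts from fresh random weights...
restored = CheckpointNet()
# ...and loading the checkpoint must bring the old weights back exactly
restored.load_state_dict(torch.load(ckpt_path))

for p_old, p_new in zip(model.parameters(), restored.parameters()):
    assert torch.equal(p_old, p_new)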
@@ -568,16 +575,23 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
     clear_save_dir()
 
 
-def get_hparams():
+def get_hparams(continue_training=False):
     root_dir = os.path.dirname(os.path.realpath(__file__))
-    hparams = Namespace(**{'drop_prob': 0.2,
-                           'batch_size': 32,
-                           'in_features': 28*28,
-                           'learning_rate': 0.001*8,
-                           'optimizer_name': 'adam',
-                           'data_root': os.path.join(root_dir, 'mnist'),
-                           'out_features': 10,
-                           'hidden_dim': 1000})
+    args = {
+        'drop_prob': 0.2,
+        'batch_size': 32,
+        'in_features': 28*28,
+        'learning_rate': 0.001*8,
+        'optimizer_name': 'adam',
+        'data_root': os.path.join(root_dir, 'mnist'),
+        'out_features': 10,
+        'hidden_dim': 1000}
+
+    if continue_training:
+        args['test_tube_do_checkpoint_load'] = True
+
+    hparams = Namespace(**args)
     return hparams
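For reference, the new continue_training flag only toggles one extra hparam. Below is the patched helper copied standalone from the hunk above, followed by a usage check; the two assertions are illustrative and not part of the commit.

# get_hparams is copied from the hunk above so this snippet runs on its own.
import os
from argparse import Namespace


def get_hparams(continue_training=False):
    root_dir = os.path.dirname(os.path.realpath(__file__))
    args = {
        'drop_prob': 0.2,
        'batch_size': 32,
        'in_features': 28*28,
        'learning_rate': 0.001*8,
        'optimizer_name': 'adam',
        'data_root': os.path.join(root_dir, 'mnist'),
        'out_features': 10,
        'hidden_dim': 1000}

    # resuming after a SLURM requeue: ask test-tube to load the checkpoint
    if continue_training:
        args['test_tube_do_checkpoint_load'] = True

    hparams = Namespace(**args)
    return hparams


# a fresh run carries no checkpoint-load flag
assert not hasattr(get_hparams(), 'test_tube_do_checkpoint_load')
# a resumed run does
assert get_hparams(continue_training=True).test_tube_do_checkpoint_load is True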