Merge pull request #20 from williamFalcon/test2

Test2
William Falcon 2019-07-26 12:47:13 -04:00 committed by GitHub
commit 56d41eaa8c
2 changed files with 87 additions and 2 deletions


@@ -86,7 +86,7 @@ class TrainerIO(object):
    # --------------------
    # HPC IO
    # --------------------
-    def enable_auto_hpc_walltime_manager(self):  # pragma: no cover
+    def enable_auto_hpc_walltime_manager(self):
        if self.cluster is None:
            return
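Note: dropping the `# pragma: no cover` means enable_auto_hpc_walltime_manager() is now exercised by the new CPU SLURM test below. For context, the usual auto-walltime pattern registers a signal handler so the job can checkpoint itself before SLURM kills it; a minimal sketch (the SIGUSR1 convention here is an assumption, not the committed TrainerIO code):

import signal

def register_walltime_handler(trainer, save_dir, exp):
    # assumption: the job was submitted with something like `#SBATCH --signal=SIGUSR1@90`,
    # so SLURM delivers SIGUSR1 shortly before the walltime limit
    def on_walltime(signum, frame):
        trainer.hpc_save(save_dir, exp)  # checkpoint before the job is requeued/killed

    signal.signal(signal.SIGUSR1, on_walltime)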
@@ -157,6 +157,8 @@ class TrainerIO(object):
        # do the actual save
        torch.save(checkpoint_dict, filepath)
        return filepath

    def hpc_load(self, folderpath, on_gpu):
        filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
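hpc_load() picks the newest checkpoint via self.max_ckpt_in_folder(), which this diff doesn't show. A plausible stand-in (an illustrative assumption, not the committed helper) simply takes the largest numeric suffix among the hpc_ckpt_<N>.ckpt files:

import os
import re

def max_ckpt_in_folder(folderpath):
    # assumption: HPC checkpoints are named hpc_ckpt_<N>.ckpt, matching hpc_load above
    numbers = []
    for name in os.listdir(folderpath):
        match = re.match(r'hpc_ckpt_(\d+)\.ckpt$', name)
        if match:
            numbers.append(int(match.group(1)))
    return max(numbers) if numbers else 0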


@@ -3,16 +3,18 @@ from pytorch_lightning import Trainer
from pytorch_lightning.examples.new_project_templates.lightning_module_template import LightningTemplateModel
from pytorch_lightning.testing_models.lm_test_module import LightningTestModel
from argparse import Namespace
-from test_tube import Experiment
+from test_tube import Experiment, SlurmCluster
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utils.debugging import MisconfigurationException
from pytorch_lightning.root_module import memory
from pytorch_lightning.models.trainer import reduce_distributed_output
from pytorch_lightning.root_module import model_saving
import numpy as np
import warnings
import torch
import os
import shutil
import pdb
SEED = 2334
torch.manual_seed(SEED)
@@ -22,6 +24,25 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------
def test_loading_meta_tags():
    hparams = get_hparams()
    save_dir = init_save_dir()

    # save tags
    exp = get_exp(False)
    exp.tag({'some_str': 'a_str', 'an_int': 1, 'a_float': 2.0})
    exp.argparse(hparams)
    exp.save()

    # load tags
    tags_path = exp.get_data_path(exp.name, exp.version) + '/meta_tags.csv'
    tags = model_saving.load_hparams_from_tags_csv(tags_path)
    assert tags.batch_size == 32 and tags.hidden_dim == 1000

    clear_save_dir()
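The assertion relies on model_saving.load_hparams_from_tags_csv() recovering typed hparams from the meta_tags.csv that test-tube wrote. A simplified sketch of that kind of loader (the column names and type coercion are assumptions; the real function may differ):

import csv
from argparse import Namespace

def load_hparams_from_tags_csv_sketch(tags_csv):
    # assumption: the csv has 'key' and 'value' columns
    with open(tags_csv) as f:
        tags = {row['key']: row['value'] for row in csv.DictReader(f)}

    # csv values come back as strings, so coerce numerics (e.g. batch_size -> int)
    for key, value in list(tags.items()):
        for cast in (int, float):
            try:
                tags[key] = cast(value)
                break
            except ValueError:
                pass
    return Namespace(**tags)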
def test_dp_output_reduce():
    # test identity when we have a single gpu
@@ -43,6 +64,68 @@ def test_dp_output_reduce():
    assert reduced['b']['c'] == out['b']['c']
def test_cpu_slurm_saving_loading():
    """
    Verify model save/load/checkpoint on CPU
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    exp = get_exp(False)
    exp.argparse(hparams)
    exp.save()

    trainer_options = dict(
        max_nb_epochs=1,
        cluster=SlurmCluster(),
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'cpu slurm model failed to complete'

    # test saving checkpoint
    ckpt_test = os.path.join(save_dir, 'test.ckpt')
    trainer.save_checkpoint(ckpt_test)

    # test registering a save function
    trainer.enable_auto_hpc_walltime_manager()

    # test model loading with a map_location
    pretrained_model = load_model(exp, save_dir, True)

    # test model preds
    run_prediction(model.test_dataloader, pretrained_model)

    trainer.model = pretrained_model
    trainer.optimizers = pretrained_model.configure_optimizers()

    # test HPC saving
    saved_filepath = trainer.hpc_save(save_dir, exp)
    assert os.path.exists(saved_filepath)

    # test HPC loading
    trainer.global_step = 20000000  # impossible value; hpc_load must restore the real step
    trainer.hpc_load(save_dir, on_gpu=False)
    assert trainer.global_step == real_global_step and trainer.global_step != 20000000

    # test freeze/unfreeze
    model.freeze()
    model.unfreeze()

    clear_save_dir()
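load_model(exp, save_dir, True) above is a test helper that restores the weights onto CPU via torch.load's map_location; the core of such a helper looks roughly like this (the argument order and checkpoint layout are assumptions, not the suite's exact code):

import torch

def load_model_sketch(weights_path, hparams, on_cpu=True):
    # assumption: checkpoints store the weights under a 'state_dict' key
    map_location = 'cpu' if on_cpu else None
    checkpoint = torch.load(weights_path, map_location=map_location)
    model = LightningTestModel(hparams)
    model.load_state_dict(checkpoint['state_dict'])
    return model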
def test_amp_gpu_ddp_slurm_managed():
    """
    Make sure DDP + AMP work