fixed lr scheduler tests

williamFalcon 2019-07-28 06:21:41 -07:00
parent 27660b8a96
commit b9e0d841dc
1 changed file with 29 additions and 29 deletions


@@ -24,6 +24,33 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+def test_amp_gpu_ddp():
+    """
+    Make sure DDP + AMP work
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        progress_bar=True,
+        max_nb_epochs=1,
+        gpus=[0, 1],
+        distributed_backend='ddp',
+        use_amp=True
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
+
 def test_cpu_slurm_save_load():
     """

@@ -280,7 +307,7 @@ def test_amp_gpu_ddp_slurm_managed():
     if trainer.use_ddp:
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
-        trainer.optimizers = pretrained_model.configure_optimizers()
+        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
     # test HPC loading / saving
     trainer.hpc_save(save_dir, exp)

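The unpacking above reflects the return shape the test now expects from configure_optimizers: a pair of (optimizers, lr_schedulers), which the trainer keeps as trainer.optimizers and trainer.lr_schedulers. A minimal sketch of a configure_optimizers with that shape, assuming an Adam optimizer and a StepLR scheduler purely for illustration (not the actual LightningTestModel implementation):

    def configure_optimizers(self):
        # illustrative choices; the real test model may use different hyperparameters
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
        # one list of optimizers, one list of LR schedulers
        return [optimizer], [scheduler]
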
@@ -477,33 +504,6 @@ def test_multi_gpu_model_ddp():
     run_gpu_model_test(trainer_options, model, hparams)
 
 
-def test_amp_gpu_ddp():
-    """
-    Make sure DDP + AMP work
-    :return:
-    """
-    if not torch.cuda.is_available():
-        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
-        return
-    if not torch.cuda.device_count() > 1:
-        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
-        return
-
-    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
-
-    hparams = get_hparams()
-    model = LightningTestModel(hparams)
-
-    trainer_options = dict(
-        progress_bar=True,
-        max_nb_epochs=1,
-        gpus=[0, 1],
-        distributed_backend='ddp',
-        use_amp=True
-    )
-
-    run_gpu_model_test(trainer_options, model, hparams)
-
-
 def test_ddp_sampler_error():
     """

@@ -574,7 +574,7 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
     if trainer.use_ddp:
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
-        trainer.optimizers = pretrained_model.configure_optimizers()
+        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
     # test HPC loading / saving
     trainer.hpc_save(save_dir, exp)
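
For context (an illustrative note, not part of the diff): with a two-element return value, the old single-target assignment would have bound the whole (optimizers, schedulers) tuple to trainer.optimizers and left trainer.lr_schedulers unset, which is presumably what the lr scheduler tests tripped over. A quick sanity check in the same spirit:

    # hypothetical check, not in the test file
    optimizers, lr_schedulers = pretrained_model.configure_optimizers()
    assert isinstance(optimizers, (list, tuple))
    assert isinstance(lr_schedulers, (list, tuple))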