fixed lr scheduler tests
commit b9e0d841dc
parent 27660b8a96
@@ -24,6 +24,33 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+def test_amp_gpu_ddp():
+    """
+    Make sure DDP + AMP work
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        progress_bar=True,
+        max_nb_epochs=1,
+        gpus=[0, 1],
+        distributed_backend='ddp',
+        use_amp=True
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
 
 def test_cpu_slurm_save_load():
     """
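One detail worth noting in the hunk above: the relocated test seeds MASTER_PORT with a random value before launching DDP. A small standalone restatement of that line, with the (assumed, not stated in the diff) rationale spelled out in comments:

import os

import numpy as np

# Assumption: picking a fresh rendezvous port for each run keeps back-to-back
# DDP tests on the same node from failing with "address already in use" when
# the previous process group has not yet released its port.
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])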
@@ -280,7 +307,7 @@ def test_amp_gpu_ddp_slurm_managed():
     if trainer.use_ddp:
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
-        trainer.optimizers = pretrained_model.configure_optimizers()
+        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
     # test HPC loading / saving
     trainer.hpc_save(save_dir, exp)
@@ -477,33 +504,6 @@ def test_multi_gpu_model_ddp():
     run_gpu_model_test(trainer_options, model, hparams)
 
 
-def test_amp_gpu_ddp():
-    """
-    Make sure DDP + AMP work
-    :return:
-    """
-    if not torch.cuda.is_available():
-        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
-        return
-    if not torch.cuda.device_count() > 1:
-        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
-        return
-
-    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
-
-    hparams = get_hparams()
-    model = LightningTestModel(hparams)
-
-    trainer_options = dict(
-        progress_bar=True,
-        max_nb_epochs=1,
-        gpus=[0, 1],
-        distributed_backend='ddp',
-        use_amp=True
-    )
-
-    run_gpu_model_test(trainer_options, model, hparams)
-
 
 def test_ddp_sampler_error():
     """
@@ -574,7 +574,7 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
     if trainer.use_ddp:
         # on hpc this would work fine... but need to hack it for the purpose of the test
         trainer.model = pretrained_model
-        trainer.optimizers = pretrained_model.configure_optimizers()
+        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
     # test HPC loading / saving
     trainer.hpc_save(save_dir, exp)
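The two one-line changes above (in test_amp_gpu_ddp_slurm_managed and run_gpu_model_test) track a change in what the test model's configure_optimizers returns: a pair that unpacks into optimizers and LR schedulers, rather than just the optimizers. A minimal sketch of that return shape, using a hypothetical stand-in module and illustrative hyperparameters instead of the real LightningTestModel; the list-of-each convention is an assumption based on the unpacking in the diff:

import torch
from torch import nn


class TinyModel(nn.Module):
    """Stand-in for LightningTestModel, used only to show the return shape."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def configure_optimizers(self):
        # Return optimizers and matching LR schedulers so callers can unpack:
        # trainer.optimizers, trainer.lr_schedulers = model.configure_optimizers()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return [optimizer], [scheduler]


optimizers, lr_schedulers = TinyModel().configure_optimizers()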