added cpu 16 bit

This commit is contained in:
William Falcon 2019-07-24 19:05:46 -04:00
parent ed9d977c4a
commit efbd1a1c18
1 changed files with 22 additions and 20 deletions

View File

@ -21,6 +21,28 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
# TESTS # TESTS
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
def test_cpu_model_with_amp():
"""
Make sure model trains on CPU
:return:
"""
trainer_options = dict(
progress_bar=False,
experiment=get_exp(),
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4,
use_amp=True
)
model, hparams = get_model()
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
def test_amp_gpu_ddp_slurm_managed(): def test_amp_gpu_ddp_slurm_managed():
""" """
Make sure DDP + AMP work Make sure DDP + AMP work
@ -111,26 +133,6 @@ def test_cpu_model():
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
def test_cpu_model_with_amp():
"""
Make sure model trains on CPU
:return:
"""
trainer_options = dict(
progress_bar=False,
experiment=get_exp(),
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4,
use_amp=True
)
model, hparams = get_model()
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
def test_all_features_cpu_model(): def test_all_features_cpu_model():
""" """
Test each of the trainer options Test each of the trainer options