added cpu 16 bit
This commit is contained in:
parent
ed9d977c4a
commit
efbd1a1c18
|
@ -21,6 +21,28 @@ np.random.seed(SEED)
|
|||
# ------------------------------------------------------------------------
|
||||
# TESTS
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_cpu_model_with_amp():
|
||||
"""
|
||||
Make sure model trains on CPU
|
||||
:return:
|
||||
"""
|
||||
|
||||
trainer_options = dict(
|
||||
progress_bar=False,
|
||||
experiment=get_exp(),
|
||||
max_nb_epochs=1,
|
||||
train_percent_check=0.4,
|
||||
val_percent_check=0.4,
|
||||
use_amp=True
|
||||
)
|
||||
|
||||
model, hparams = get_model()
|
||||
|
||||
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
|
||||
|
||||
|
||||
def test_amp_gpu_ddp_slurm_managed():
|
||||
"""
|
||||
Make sure DDP + AMP work
|
||||
|
@ -111,26 +133,6 @@ def test_cpu_model():
|
|||
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
|
||||
|
||||
|
||||
def test_cpu_model_with_amp():
|
||||
"""
|
||||
Make sure model trains on CPU
|
||||
:return:
|
||||
"""
|
||||
|
||||
trainer_options = dict(
|
||||
progress_bar=False,
|
||||
experiment=get_exp(),
|
||||
max_nb_epochs=1,
|
||||
train_percent_check=0.4,
|
||||
val_percent_check=0.4,
|
||||
use_amp=True
|
||||
)
|
||||
|
||||
model, hparams = get_model()
|
||||
|
||||
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
|
||||
|
||||
|
||||
def test_all_features_cpu_model():
|
||||
"""
|
||||
Test each of the trainer options
|
||||
|
|
Loading…
Reference in New Issue