added cpu 16 bit
This commit is contained in:
parent
65ce10c255
commit
ed9d977c4a
|
@ -344,7 +344,7 @@ class Trainer(TrainerIO):
|
|||
# run training
|
||||
for batch_i, data_batch in enumerate(dataloader):
|
||||
|
||||
if data_batch is None:
|
||||
if data_batch is None: # pragma: no cover
|
||||
continue
|
||||
|
||||
# stop short when on fast dev run
|
||||
|
|
|
@ -111,6 +111,26 @@ def test_cpu_model():
|
|||
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
|
||||
|
||||
|
||||
def test_cpu_model_with_amp():
|
||||
"""
|
||||
Make sure model trains on CPU
|
||||
:return:
|
||||
"""
|
||||
|
||||
trainer_options = dict(
|
||||
progress_bar=False,
|
||||
experiment=get_exp(),
|
||||
max_nb_epochs=1,
|
||||
train_percent_check=0.4,
|
||||
val_percent_check=0.4,
|
||||
use_amp=True
|
||||
)
|
||||
|
||||
model, hparams = get_model()
|
||||
|
||||
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
|
||||
|
||||
|
||||
def test_all_features_cpu_model():
|
||||
"""
|
||||
Test each of the trainer options
|
||||
|
|
Loading…
Reference in New Issue