From ed9d977c4addf708c7cd4cdcae2ef8c292e0461e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 24 Jul 2019 19:05:20 -0400
Subject: [PATCH] added cpu 16 bit

---
 pytorch_lightning/models/trainer.py |  2 +-
 tests/test_models.py                | 20 ++++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 0be187d0f8..ab1d856549 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -344,7 +344,7 @@ class Trainer(TrainerIO):
 
         # run training
         for batch_i, data_batch in enumerate(dataloader):
-            if data_batch is None:
+            if data_batch is None:  # pragma: no cover
                 continue
 
             # stop short when on fast dev run
diff --git a/tests/test_models.py b/tests/test_models.py
index ed627fcc13..84383b453a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -111,6 +111,26 @@ def test_cpu_model():
     run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
 
 
+def test_cpu_model_with_amp():
+    """
+    Make sure model trains on CPU with AMP (16-bit)
+    :return:
+    """
+
+    trainer_options = dict(
+        progress_bar=False,
+        experiment=get_exp(),
+        max_nb_epochs=1,
+        train_percent_check=0.4,
+        val_percent_check=0.4,
+        use_amp=True
+    )
+
+    model, hparams = get_model()
+
+    run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
+
+
 def test_all_features_cpu_model():
     """
     Test each of the trainer options