Fixed AMP bug: load_model now takes on_gpu instead of hard-coding on_gpu=True

This commit is contained in:
William Falcon 2019-07-24 14:17:36 -04:00
parent 5fe833ae01
commit 9e187574de
1 changed file with 6 additions and 14 deletions

View File

@ -23,24 +23,16 @@ def test_cpu_model():
Make sure model trains on CPU Make sure model trains on CPU
:return: :return:
""" """
save_dir = init_save_dir()
model, hparams = get_model() trainer_options = dict(
trainer = Trainer(
progress_bar=False, progress_bar=False,
experiment=get_exp(), experiment=get_exp(),
max_nb_epochs=1, max_nb_epochs=1,
train_percent_check=0.4, train_percent_check=0.4,
val_percent_check=0.4 val_percent_check=0.4
) )
result = trainer.fit(model)
# correct result and ok accuracy run_gpu_model_test(trainer_options, on_gpu=False)
assert result == 1, 'cpu model failed to complete'
assert_ok_acc(trainer)
clear_save_dir()
def test_single_gpu_model(): def test_single_gpu_model():
@ -162,7 +154,7 @@ def test_amp_gpu_ddp():
# UTILS # UTILS
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
def run_gpu_model_test(trainer_options): def run_gpu_model_test(trainer_options, on_gpu=True):
""" """
Make sure DDP + AMP work Make sure DDP + AMP work
:return: :return:
@ -197,7 +189,7 @@ def run_gpu_model_test(trainer_options):
assert result == 1, 'amp + ddp model failed to complete' assert result == 1, 'amp + ddp model failed to complete'
# test model loading # test model loading
pretrained_model = load_model(exp, save_dir) pretrained_model = load_model(exp, save_dir, on_gpu)
# test model preds # test model preds
run_prediction(model.test_dataloader, pretrained_model) run_prediction(model.test_dataloader, pretrained_model)
@ -247,7 +239,7 @@ def clear_save_dir():
shutil.rmtree(save_dir) shutil.rmtree(save_dir)
def load_model(exp, save_dir): def load_model(exp, save_dir, on_gpu):
# load trained model # load trained model
tags_path = exp.get_data_path(exp.name, exp.version) tags_path = exp.get_data_path(exp.name, exp.version)
@ -256,7 +248,7 @@ def load_model(exp, save_dir):
checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x] checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x]
weights_dir = os.path.join(save_dir, checkpoints[0]) weights_dir = os.path.join(save_dir, checkpoints[0])
trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=True) trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=on_gpu)
assert trained_model is not None, 'loading model failed' assert trained_model is not None, 'loading model failed'