fixed amp bug
parent 5fe833ae01
commit 9e187574de
@@ -23,24 +23,16 @@ def test_cpu_model():
     Make sure model trains on CPU
     :return:
     """
-    save_dir = init_save_dir()
-
-    model, hparams = get_model()
-
-    trainer = Trainer(
+
+    trainer_options = dict(
         progress_bar=False,
         experiment=get_exp(),
         max_nb_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.4
     )
-    result = trainer.fit(model)
 
-    # correct result and ok accuracy
-    assert result == 1, 'cpu model failed to complete'
-    assert_ok_acc(trainer)
-
-    clear_save_dir()
+    run_gpu_model_test(trainer_options, on_gpu=False)
 
 
 def test_single_gpu_model():
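For readability, this is how test_cpu_model reads once the hunk above is applied. It is simply the added lines merged with the unchanged context lines, so only the exact blank-line placement is assumed:

def test_cpu_model():
    """
    Make sure model trains on CPU
    :return:
    """
    trainer_options = dict(
        progress_bar=False,
        experiment=get_exp(),
        max_nb_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.4
    )

    # CPU path now reuses the shared test driver with on_gpu=False
    run_gpu_model_test(trainer_options, on_gpu=False)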
@@ -162,7 +154,7 @@ def test_amp_gpu_ddp():
 # UTILS
 # ------------------------------------------------------------------------
 
-def run_gpu_model_test(trainer_options):
+def run_gpu_model_test(trainer_options, on_gpu=True):
     """
     Make sure DDP + AMP work
     :return:
@@ -197,7 +189,7 @@ def run_gpu_model_test(trainer_options):
     assert result == 1, 'amp + ddp model failed to complete'
 
     # test model loading
-    pretrained_model = load_model(exp, save_dir)
+    pretrained_model = load_model(exp, save_dir, on_gpu)
 
     # test model preds
     run_prediction(model.test_dataloader, pretrained_model)
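The two hunks above only show the signature change and the load_model call inside run_gpu_model_test. A rough sketch of the whole helper, pieced together from the surrounding hunks, is below; Trainer(**trainer_options), the init_save_dir()/get_model() setup, the exp lookup, and the trailing clear_save_dir() are assumptions modelled on the old test_cpu_model body removed above, not lines shown in this diff:

def run_gpu_model_test(trainer_options, on_gpu=True):
    """
    Make sure DDP + AMP work
    :return:
    """
    # assumed setup, mirroring the old test_cpu_model body removed above
    save_dir = init_save_dir()
    model, hparams = get_model()

    # build the Trainer from the shared options dict (assumed)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    exp = trainer_options['experiment']  # assumption: the experiment comes from the options dict

    assert result == 1, 'amp + ddp model failed to complete'

    # test model loading -- on_gpu is now forwarded instead of being hard-coded
    pretrained_model = load_model(exp, save_dir, on_gpu)

    # test model preds
    run_prediction(model.test_dataloader, pretrained_model)

    clear_save_dir()  # assumed cleanup, as in the old CPU test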
@@ -247,7 +239,7 @@ def clear_save_dir():
     shutil.rmtree(save_dir)
 
 
-def load_model(exp, save_dir):
+def load_model(exp, save_dir, on_gpu):
 
     # load trained model
     tags_path = exp.get_data_path(exp.name, exp.version)
@@ -256,7 +248,7 @@ def load_model(exp, save_dir):
     checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x]
     weights_dir = os.path.join(save_dir, checkpoints[0])
 
-    trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=True)
+    trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=on_gpu)
 
     assert trained_model is not None, 'loading model failed'
 
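Putting the last two hunks together: load_model now forwards the caller's flag instead of the hard-coded on_gpu=True, which is presumably the bug the commit title refers to, since the CPU-only test would otherwise try to restore the checkpoint as a GPU model. A minimal sketch of the resulting helper follows; the lines between the two hunks and the final return are assumptions (run_gpu_model_test binds the result, so something must be returned):

def load_model(exp, save_dir, on_gpu):

    # load trained model
    tags_path = exp.get_data_path(exp.name, exp.version)

    # pick up the saved checkpoint (first .ckpt file in save_dir)
    checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x]
    weights_dir = os.path.join(save_dir, checkpoints[0])

    # restore on CPU or GPU depending on the caller, instead of always on_gpu=True
    trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir,
                                                             tags_csv=tags_path,
                                                             on_gpu=on_gpu)

    assert trained_model is not None, 'loading model failed'

    return trained_model  # assumed: callers bind this as pretrained_model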