diff --git a/tests/test_models.py b/tests/test_models.py index 5ead6de8ba..768acd05b7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -90,7 +90,6 @@ def test_model_saving_loading(): # fit model trainer = Trainer(**trainer_options) result = trainer.fit(model) - real_global_step = trainer.global_step # traning complete assert result == 1, 'amp + ddp model failed to complete' @@ -124,7 +123,7 @@ def test_model_saving_loading(): clear_save_dir() -def test_cpu_slurm_saving_loading(): +def test_cpu_slurm_save_load(): """ Verify model save/load/checkpoint on CPU :return: @@ -154,38 +153,49 @@ def test_cpu_slurm_saving_loading(): # traning complete assert result == 1, 'amp + ddp model failed to complete' - # test saving checkpoint - ckpt_test = os.path.join(save_dir, 'test.ckpt') - trainer.save_checkpoint(ckpt_test) + # predict with trained model before saving + # make a prediction + for batch in model.test_dataloader: + break + + x, y = batch + x = x.view(x.size(0), -1) + + model.eval() + pred_before_saving = model(x) # test registering a save function trainer.enable_auto_hpc_walltime_manager() - # test model loading with a map_location - pretrained_model = load_model(exp, save_dir, True) - - # test model preds - run_prediction(model.test_dataloader, pretrained_model) - - trainer.model = pretrained_model - trainer.optimizers = pretrained_model.configure_optimizers() - # test HPC saving + # simulate snapshot on slurm saved_filepath = trainer.hpc_save(save_dir, exp) assert os.path.exists(saved_filepath) + # wipe-out trainer model + # we want to see if the weights come back correctly + trainer.model = LightningTestModel(hparams) + # test HPC loading trainer.global_step = 20000000 trainer.hpc_load(save_dir, on_gpu=False) assert trainer.global_step == real_global_step and trainer.global_step != 20000000 - # test freeze on gpu - model.freeze() - model.unfreeze() + # predict with loaded model to make sure answers are the same + new_pred = trainer.model(x) + assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 clear_save_dir() +def test_model_freeze_unfreeze(): + hparams = get_hparams() + model = LightningTestModel(hparams) + + model.freeze() + model.unfreeze() + + def test_amp_gpu_ddp_slurm_managed(): """ Make sure DDP + AMP work