Tests: refactor models (#1691)

* refactor default model

* drop redundant seeds

* drop redundant seeds

* refactor models tests

* refactor models tests

* imports

* fix conf

* Apply suggestions from code review
Jirka Borovec 2020-05-04 17:38:08 +02:00 committed by GitHub
parent d28b145393
commit 1077159834
8 changed files with 67 additions and 86 deletions
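The change repeated across these files is replacing the mixin-composed test models with the single EvalModelTemplate. A minimal before/after sketch of that pattern, using only helper names that appear in the diffs below (tests.base is the repository's own test-support package):

import tests.base.utils as tutils

# before: each test assembled its own model from mixins on top of TestModelBase
from tests.base import LightTrainDataloader, LightTestMixin, TestModelBase

class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
    pass

model = CurrentTestModel(tutils.get_default_hparams())

# after: one ready-made template covers the common case
from tests.base import EvalModelTemplate

model = EvalModelTemplate(tutils.get_default_hparams())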

View File

@@ -1535,8 +1535,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
else:
rank_zero_warn(
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
f"contains argument 'hparams'. Will pass in an empty Namespace instead."
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
" contains argument 'hparams'. Will pass in an empty Namespace instead."
" Did you forget to store your model hyperparameters in self.hparams?"
)
hparams = Namespace()
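For orientation, this reworded warning fires on the restore path when the checkpoint stores no hyperparameters but the model class still declares an hparams argument; the loader then falls back to an empty Namespace. A rough sketch of that situation, mirroring the test added at the end of this commit (the top-level 'hparams' checkpoint key and the CurrentModelUnusedHparams class are taken from that test, not a general guarantee):

import torch

# strip the stored hyperparameters from a checkpoint written earlier by a fit
ckpt = torch.load(last_checkpoint)
del ckpt['hparams']
torch.save(ckpt, last_checkpoint)

# the class still declares __init__(self, hparams), so loading it now emits
# the empty-Namespace warning above
CurrentModelUnusedHparams.load_from_checkpoint(last_checkpoint)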

View File

@@ -249,8 +249,7 @@ def test_pickling(tmpdir):
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
""" Test that None in checkpoint callback is valid and that chkp_path is
set correctly """
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
tutils.reset_seed()
class CurrentTestModel(LightTrainDataloader, TestModelBase):

View File

@@ -6,7 +6,7 @@ import torch
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel, EvalModelTemplate
from tests.base import EvalModelTemplate
@pytest.mark.spawn
@@ -15,7 +15,6 @@ from tests.base import LightningTestModel, EvalModelTemplate
def test_amp_single_gpu(tmpdir, backend):
"""Make sure DP/DDP + AMP work."""
tutils.reset_seed()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
@@ -63,8 +62,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
tutils.set_random_master_port()
os.environ['SLURM_LOCALID'] = str(0)
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
# exp file to get meta
logger = tutils.get_default_logger(tmpdir)

View File

@@ -7,16 +7,8 @@ from packaging.version import parse as version_parse
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
)
from tests.base import (
TestModelBase,
LightTrainDataloader,
LightningTestModel,
LightTestMixin,
EvalModelTemplate,
)
from pytorch_lightning.callbacks import EarlyStopping
from tests.base import EvalModelTemplate
def test_early_stopping_cpu_model(tmpdir):
@@ -106,8 +98,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
def test_running_test_after_fitting(tmpdir):
"""Verify test() on fitted model."""
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
# logger file to get meta
logger = tutils.get_default_logger(tmpdir)
@@ -138,11 +129,7 @@ def test_running_test_after_fitting(tmpdir):
def test_running_test_no_val(tmpdir):
"""Verify `test()` works on a model with no `val_loader`."""
class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
pass
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
# logger file to get meta
logger = tutils.get_default_logger(tmpdir)
@@ -220,8 +207,7 @@ def test_single_gpu_batch_parse():
def test_simple_cpu(tmpdir):
"""Verify continue training session on CPU."""
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
# fit model
trainer = Trainer(
@@ -285,7 +271,7 @@ def test_tbptt_cpu_model(tmpdir):
def __len__(self):
return 1
class BpttTestModel(LightTrainDataloader, TestModelBase):
class BpttTestModel(EvalModelTemplate):
def __init__(self, hparams):
super().__init__(hparams)
self.test_hidden = None

View File

@@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core import memory
from pytorch_lightning.trainer.distrib_parts import parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel, EvalModelTemplate
from tests.base import EvalModelTemplate
PRETEND_N_OF_GPUS = 16
@@ -65,7 +65,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
def test_cpu_slurm_save_load(tmpdir):
"""Verify model save/load/checkpoint on CPU."""
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(hparams)
# logger file to get meta
logger = tutils.get_default_logger(tmpdir)
@@ -112,7 +112,7 @@ def test_cpu_slurm_save_load(tmpdir):
logger=logger,
checkpoint_callback=ModelCheckpoint(tmpdir),
)
model = LightningTestModel(hparams)
model = EvalModelTemplate(hparams)
# set the epoch start hook so we can predict before the model does the full training
def assert_pred_same():

View File

@@ -2,29 +2,19 @@ import pytest
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from tests.base import (
LightTrainDataloader,
LightValidationMixin,
TestModelBase,
LightTestMixin)
from tests.base import EvalModelTemplate
@pytest.mark.parametrize('max_steps', [1, 2, 3])
def test_on_before_zero_grad_called(max_steps):
class CurrentTestModel(
LightTrainDataloader,
LightValidationMixin,
LightTestMixin,
TestModelBase,
):
class CurrentTestModel(EvalModelTemplate):
on_before_zero_grad_called = 0
def on_before_zero_grad(self, optimizer):
self.on_before_zero_grad_called += 1
hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)
model = CurrentTestModel(tutils.get_default_hparams())
trainer = Trainer(
max_steps=max_steps,

View File

@@ -11,7 +11,7 @@ import torch
from pytorch_lightning import Trainer
import tests.base.utils as tutils
from tests.base import LightningTestModel
from tests.base import EvalModelTemplate
from tests.base.models import TestGAN
try:
@@ -107,7 +107,8 @@ def test_horovod_multi_gpu(tmpdir):
@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_horovod_transfer_batch_to_gpu(tmpdir):
class TestTrainingStepModel(LightningTestModel):
class TestTrainingStepModel(EvalModelTemplate):
def training_step(self, batch, *args, **kwargs):
x, y = batch
assert str(x.device) != 'cpu'

View File

@@ -9,11 +9,7 @@ import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
LightningTestModel,
LightningTestModelWithoutHyperparametersArg,
LightningTestModelWithUnusedHyperparametersArg
)
from tests.base import EvalModelTemplate
@pytest.mark.spawn
@@ -23,8 +19,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
"""Verify `test()` on pretrained model."""
tutils.set_random_master_port()
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
# exp file to get meta
logger = tutils.get_default_logger(tmpdir)
@@ -53,7 +48,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(logger,
trainer.checkpoint_callback.dirpath,
module_class=LightningTestModel)
module_class=EvalModelTemplate)
# run test set
new_trainer = Trainer(**trainer_options)
@@ -72,8 +67,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
def test_running_test_pretrained_model_cpu(tmpdir):
"""Verify test() on pretrained model."""
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
# logger file to get meta
logger = tutils.get_default_logger(tmpdir)
@@ -97,7 +91,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(
logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
logger, trainer.checkpoint_callback.dirpath, module_class=EvalModelTemplate
)
new_trainer = Trainer(**trainer_options)
@@ -110,7 +104,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
def test_load_model_from_checkpoint(tmpdir):
"""Verify test() on pretrained model."""
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(hparams)
trainer_options = dict(
progress_bar_refresh_rate=0,
@@ -131,7 +125,7 @@ def test_load_model_from_checkpoint(tmpdir):
# load last checkpoint
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)
pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint)
# test that hparams loaded correctly
for k, v in vars(hparams).items():
@@ -152,7 +146,13 @@ def test_load_model_from_checkpoint(tmpdir):
def test_dp_resume(tmpdir):
"""Make sure DP continues training correctly."""
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(hparams)
trainer_options = dict(
max_epochs=1,
gpus=2,
distributed_backend='dp',
)
# get logger
logger = tutils.get_default_logger(tmpdir)
@@ -161,13 +161,9 @@ def test_dp_resume(tmpdir):
# logger file to get weights
checkpoint = tutils.init_checkpoint_callback(logger)
trainer_options = dict(
max_epochs=1,
gpus=2,
distributed_backend='dp',
logger=logger,
checkpoint_callback=checkpoint,
)
# add these to the trainer options
trainer_options['logger'] = logger
trainer_options['checkpoint_callback'] = checkpoint
# fit model
trainer = Trainer(**trainer_options)
@@ -188,13 +184,11 @@ def test_dp_resume(tmpdir):
# init new trainer
new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
trainer_options.update(
logger=new_logger,
checkpoint_callback=ModelCheckpoint(tmpdir),
train_percent_check=0.5,
val_percent_check=0.2,
max_epochs=1,
)
trainer_options['logger'] = new_logger
trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
trainer_options['train_percent_check'] = 0.5
trainer_options['val_percent_check'] = 0.2
trainer_options['max_epochs'] = 1
new_trainer = Trainer(**trainer_options)
# set the epoch start hook so we can predict before the model does the full training
@@ -210,7 +204,7 @@ def test_dp_resume(tmpdir):
tutils.run_prediction(dataloader, dp_model, dp=True)
# new model
model = LightningTestModel(hparams)
model = EvalModelTemplate(hparams)
model.on_train_start = assert_good_acc
# fit new model which should load hpc weights
@@ -223,18 +217,19 @@ def test_dp_resume(tmpdir):
def test_model_saving_loading(tmpdir):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
model = EvalModelTemplate(tutils.get_default_hparams())
# logger file to get meta
logger = tutils.get_default_logger(tmpdir)
# fit model
trainer = Trainer(
trainer_options = dict(
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(tmpdir)
)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# training complete
@@ -263,7 +258,7 @@ def test_model_saving_loading(tmpdir):
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_checkpoint(
model_2 = EvalModelTemplate.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
)
@@ -276,8 +271,7 @@ def test_model_saving_loading(tmpdir):
def test_load_model_with_missing_hparams(tmpdir):
# fit model
trainer = Trainer(
trainer_options = dict(
progress_bar_refresh_rate=0,
max_epochs=1,
checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
@@ -285,22 +279,35 @@ def test_load_model_with_missing_hparams(tmpdir):
default_root_dir=tmpdir,
)
model = LightningTestModelWithoutHyperparametersArg()
# fit model
trainer = Trainer(**trainer_options)
class CurrentModelWithoutHparams(EvalModelTemplate):
def __init__(self):
hparams = tutils.get_default_hparams()
super().__init__(hparams)
class CurrentModelUnusedHparams(EvalModelTemplate):
def __init__(self, hparams):
hparams = tutils.get_default_hparams()
super().__init__(hparams)
model = CurrentModelWithoutHparams()
trainer.fit(model)
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
# try to load a checkpoint that has hparams but model is missing hparams arg
with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"):
LightningTestModelWithoutHyperparametersArg.load_from_checkpoint(last_checkpoint)
CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)
# create a checkpoint without hyperparameters
# if the model does not take a hparams argument, it should not throw an error
ckpt = torch.load(last_checkpoint)
del(ckpt['hparams'])
torch.save(ckpt, last_checkpoint)
LightningTestModelWithoutHyperparametersArg.load_from_checkpoint(last_checkpoint)
CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)
# load checkpoint without hparams again
# warn if user's model has hparams argument
with pytest.warns(UserWarning, match=r".*Will pass in an empty Namespace instead."):
LightningTestModelWithUnusedHyperparametersArg.load_from_checkpoint(last_checkpoint)
CurrentModelUnusedHparams.load_from_checkpoint(last_checkpoint)
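Taken together, the rewritten test covers three restore behaviors; a condensed recap using the classes defined above (last_checkpoint is the newest .ckpt produced by the fit):

# 1) checkpoint stores 'hparams' but the model's __init__ takes no hparams argument
with pytest.raises(MisconfigurationException, match="missing the argument 'hparams'"):
    CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)

# 2) with 'hparams' deleted from the checkpoint, the hparams-free model loads cleanly
ckpt = torch.load(last_checkpoint)
del ckpt['hparams']
torch.save(ckpt, last_checkpoint)
CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)

# 3) a model that still expects hparams gets an empty Namespace plus a UserWarning
with pytest.warns(UserWarning, match="empty Namespace"):
    CurrentModelUnusedHparams.load_from_checkpoint(last_checkpoint)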