Tests: refactor models (#1691)
* refactor default model * drop redundant seeds * drop redundant seeds * refactor models tests * refactor models tests * imports * fix conf * Apply suggestions from code review
This commit is contained in:
parent
d28b145393
commit
1077159834
|
@ -1535,8 +1535,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
|||
hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
|
||||
else:
|
||||
rank_zero_warn(
|
||||
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
|
||||
f"contains argument 'hparams'. Will pass in an empty Namespace instead."
|
||||
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
|
||||
" contains argument 'hparams'. Will pass in an empty Namespace instead."
|
||||
" Did you forget to store your model hyperparameters in self.hparams?"
|
||||
)
|
||||
hparams = Namespace()
|
||||
|
|
|
@ -249,8 +249,7 @@ def test_pickling(tmpdir):
|
|||
|
||||
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
|
||||
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
|
||||
""" Test that None in checkpoint callback is valid and that chkp_path is
|
||||
set correctly """
|
||||
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
|
||||
tutils.reset_seed()
|
||||
|
||||
class CurrentTestModel(LightTrainDataloader, TestModelBase):
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import LightningTestModel, EvalModelTemplate
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
@pytest.mark.spawn
|
||||
|
@ -15,7 +15,6 @@ from tests.base import LightningTestModel, EvalModelTemplate
|
|||
def test_amp_single_gpu(tmpdir, backend):
|
||||
"""Make sure DP/DDP + AMP work."""
|
||||
tutils.reset_seed()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
|
@ -63,8 +62,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
|
|||
tutils.set_random_master_port()
|
||||
os.environ['SLURM_LOCALID'] = str(0)
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# exp file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
|
|
@ -7,16 +7,8 @@ from packaging.version import parse as version_parse
|
|||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
)
|
||||
from tests.base import (
|
||||
TestModelBase,
|
||||
LightTrainDataloader,
|
||||
LightningTestModel,
|
||||
LightTestMixin,
|
||||
EvalModelTemplate,
|
||||
)
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
def test_early_stopping_cpu_model(tmpdir):
|
||||
|
@ -106,8 +98,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
|
|||
|
||||
def test_running_test_after_fitting(tmpdir):
|
||||
"""Verify test() on fitted model."""
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# logger file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
@ -138,11 +129,7 @@ def test_running_test_after_fitting(tmpdir):
|
|||
|
||||
def test_running_test_no_val(tmpdir):
|
||||
"""Verify `test()` works on a model with no `val_loader`."""
|
||||
class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
|
||||
pass
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# logger file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
@ -220,8 +207,7 @@ def test_single_gpu_batch_parse():
|
|||
|
||||
def test_simple_cpu(tmpdir):
|
||||
"""Verify continue training session on CPU."""
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
|
@ -285,7 +271,7 @@ def test_tbptt_cpu_model(tmpdir):
|
|||
def __len__(self):
|
||||
return 1
|
||||
|
||||
class BpttTestModel(LightTrainDataloader, TestModelBase):
|
||||
class BpttTestModel(EvalModelTemplate):
|
||||
def __init__(self, hparams):
|
||||
super().__init__(hparams)
|
||||
self.test_hidden = None
|
||||
|
|
|
@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint
|
|||
from pytorch_lightning.core import memory
|
||||
from pytorch_lightning.trainer.distrib_parts import parse_gpu_ids, determine_root_gpu_device
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import LightningTestModel, EvalModelTemplate
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
PRETEND_N_OF_GPUS = 16
|
||||
|
||||
|
@ -65,7 +65,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
|
|||
def test_cpu_slurm_save_load(tmpdir):
|
||||
"""Verify model save/load/checkpoint on CPU."""
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(hparams)
|
||||
|
||||
# logger file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
@ -112,7 +112,7 @@ def test_cpu_slurm_save_load(tmpdir):
|
|||
logger=logger,
|
||||
checkpoint_callback=ModelCheckpoint(tmpdir),
|
||||
)
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(hparams)
|
||||
|
||||
# set the epoch start hook so we can predict before the model does the full training
|
||||
def assert_pred_same():
|
||||
|
|
|
@ -2,29 +2,19 @@ import pytest
|
|||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.base import (
|
||||
LightTrainDataloader,
|
||||
LightValidationMixin,
|
||||
TestModelBase,
|
||||
LightTestMixin)
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
@pytest.mark.parametrize('max_steps', [1, 2, 3])
|
||||
def test_on_before_zero_grad_called(max_steps):
|
||||
|
||||
class CurrentTestModel(
|
||||
LightTrainDataloader,
|
||||
LightValidationMixin,
|
||||
LightTestMixin,
|
||||
TestModelBase,
|
||||
):
|
||||
class CurrentTestModel(EvalModelTemplate):
|
||||
on_before_zero_grad_called = 0
|
||||
|
||||
def on_before_zero_grad(self, optimizer):
|
||||
self.on_before_zero_grad_called += 1
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = CurrentTestModel(hparams)
|
||||
model = CurrentTestModel(tutils.get_default_hparams())
|
||||
|
||||
trainer = Trainer(
|
||||
max_steps=max_steps,
|
||||
|
|
|
@ -11,7 +11,7 @@ import torch
|
|||
from pytorch_lightning import Trainer
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from tests.base import LightningTestModel
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.base.models import TestGAN
|
||||
|
||||
try:
|
||||
|
@ -107,7 +107,8 @@ def test_horovod_multi_gpu(tmpdir):
|
|||
@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
def test_horovod_transfer_batch_to_gpu(tmpdir):
|
||||
class TestTrainingStepModel(LightningTestModel):
|
||||
|
||||
class TestTrainingStepModel(EvalModelTemplate):
|
||||
def training_step(self, batch, *args, **kwargs):
|
||||
x, y = batch
|
||||
assert str(x.device) != 'cpu'
|
||||
|
|
|
@ -9,11 +9,7 @@ import tests.base.utils as tutils
|
|||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import (
|
||||
LightningTestModel,
|
||||
LightningTestModelWithoutHyperparametersArg,
|
||||
LightningTestModelWithUnusedHyperparametersArg
|
||||
)
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
@pytest.mark.spawn
|
||||
|
@ -23,8 +19,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
|
|||
"""Verify `test()` on pretrained model."""
|
||||
tutils.set_random_master_port()
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# exp file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
@ -53,7 +48,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
|
|||
assert result == 1, 'training failed to complete'
|
||||
pretrained_model = tutils.load_model(logger,
|
||||
trainer.checkpoint_callback.dirpath,
|
||||
module_class=LightningTestModel)
|
||||
module_class=EvalModelTemplate)
|
||||
|
||||
# run test set
|
||||
new_trainer = Trainer(**trainer_options)
|
||||
|
@ -72,8 +67,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
|
|||
|
||||
def test_running_test_pretrained_model_cpu(tmpdir):
|
||||
"""Verify test() on pretrained model."""
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# logger file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
@ -97,7 +91,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
|
|||
# correct result and ok accuracy
|
||||
assert result == 1, 'training failed to complete'
|
||||
pretrained_model = tutils.load_model(
|
||||
logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
|
||||
logger, trainer.checkpoint_callback.dirpath, module_class=EvalModelTemplate
|
||||
)
|
||||
|
||||
new_trainer = Trainer(**trainer_options)
|
||||
|
@ -110,7 +104,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
|
|||
def test_load_model_from_checkpoint(tmpdir):
|
||||
"""Verify test() on pretrained model."""
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(hparams)
|
||||
|
||||
trainer_options = dict(
|
||||
progress_bar_refresh_rate=0,
|
||||
|
@ -131,7 +125,7 @@ def test_load_model_from_checkpoint(tmpdir):
|
|||
|
||||
# load last checkpoint
|
||||
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
|
||||
pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)
|
||||
pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint)
|
||||
|
||||
# test that hparams loaded correctly
|
||||
for k, v in vars(hparams).items():
|
||||
|
@ -152,7 +146,13 @@ def test_load_model_from_checkpoint(tmpdir):
|
|||
def test_dp_resume(tmpdir):
|
||||
"""Make sure DP continues training correctly."""
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(hparams)
|
||||
|
||||
trainer_options = dict(
|
||||
max_epochs=1,
|
||||
gpus=2,
|
||||
distributed_backend='dp',
|
||||
)
|
||||
|
||||
# get logger
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
@ -161,13 +161,9 @@ def test_dp_resume(tmpdir):
|
|||
# logger file to get weights
|
||||
checkpoint = tutils.init_checkpoint_callback(logger)
|
||||
|
||||
trainer_options = dict(
|
||||
max_epochs=1,
|
||||
gpus=2,
|
||||
distributed_backend='dp',
|
||||
logger=logger,
|
||||
checkpoint_callback=checkpoint,
|
||||
)
|
||||
# add these to the trainer options
|
||||
trainer_options['logger'] = logger
|
||||
trainer_options['checkpoint_callback'] = checkpoint
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
|
@ -188,13 +184,11 @@ def test_dp_resume(tmpdir):
|
|||
|
||||
# init new trainer
|
||||
new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
|
||||
trainer_options.update(
|
||||
logger=new_logger,
|
||||
checkpoint_callback=ModelCheckpoint(tmpdir),
|
||||
train_percent_check=0.5,
|
||||
val_percent_check=0.2,
|
||||
max_epochs=1,
|
||||
)
|
||||
trainer_options['logger'] = new_logger
|
||||
trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
|
||||
trainer_options['train_percent_check'] = 0.5
|
||||
trainer_options['val_percent_check'] = 0.2
|
||||
trainer_options['max_epochs'] = 1
|
||||
new_trainer = Trainer(**trainer_options)
|
||||
|
||||
# set the epoch start hook so we can predict before the model does the full training
|
||||
|
@ -210,7 +204,7 @@ def test_dp_resume(tmpdir):
|
|||
tutils.run_prediction(dataloader, dp_model, dp=True)
|
||||
|
||||
# new model
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(hparams)
|
||||
model.on_train_start = assert_good_acc
|
||||
|
||||
# fit new model which should load hpc weights
|
||||
|
@ -223,18 +217,19 @@ def test_dp_resume(tmpdir):
|
|||
|
||||
def test_model_saving_loading(tmpdir):
|
||||
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
model = EvalModelTemplate(tutils.get_default_hparams())
|
||||
|
||||
# logger file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
trainer_options = dict(
|
||||
max_epochs=1,
|
||||
logger=logger,
|
||||
checkpoint_callback=ModelCheckpoint(tmpdir)
|
||||
)
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
# traning complete
|
||||
|
@ -263,7 +258,7 @@ def test_model_saving_loading(tmpdir):
|
|||
# load new model
|
||||
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
|
||||
tags_path = os.path.join(tags_path, 'meta_tags.csv')
|
||||
model_2 = LightningTestModel.load_from_checkpoint(
|
||||
model_2 = EvalModelTemplate.load_from_checkpoint(
|
||||
checkpoint_path=new_weights_path,
|
||||
tags_csv=tags_path
|
||||
)
|
||||
|
@ -276,8 +271,7 @@ def test_model_saving_loading(tmpdir):
|
|||
|
||||
|
||||
def test_load_model_with_missing_hparams(tmpdir):
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
trainer_options = dict(
|
||||
progress_bar_refresh_rate=0,
|
||||
max_epochs=1,
|
||||
checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
|
||||
|
@ -285,22 +279,35 @@ def test_load_model_with_missing_hparams(tmpdir):
|
|||
default_root_dir=tmpdir,
|
||||
)
|
||||
|
||||
model = LightningTestModelWithoutHyperparametersArg()
|
||||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
|
||||
class CurrentModelWithoutHparams(EvalModelTemplate):
|
||||
def __init__(self):
|
||||
hparams = tutils.get_default_hparams()
|
||||
super().__init__(hparams)
|
||||
|
||||
class CurrentModelUnusedHparams(EvalModelTemplate):
|
||||
def __init__(self, hparams):
|
||||
hparams = tutils.get_default_hparams()
|
||||
super().__init__(hparams)
|
||||
|
||||
model = CurrentModelWithoutHparams()
|
||||
trainer.fit(model)
|
||||
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
|
||||
|
||||
# try to load a checkpoint that has hparams but model is missing hparams arg
|
||||
with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"):
|
||||
LightningTestModelWithoutHyperparametersArg.load_from_checkpoint(last_checkpoint)
|
||||
CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)
|
||||
|
||||
# create a checkpoint without hyperparameters
|
||||
# if the model does not take a hparams argument, it should not throw an error
|
||||
ckpt = torch.load(last_checkpoint)
|
||||
del(ckpt['hparams'])
|
||||
torch.save(ckpt, last_checkpoint)
|
||||
LightningTestModelWithoutHyperparametersArg.load_from_checkpoint(last_checkpoint)
|
||||
CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)
|
||||
|
||||
# load checkpoint without hparams again
|
||||
# warn if user's model has hparams argument
|
||||
with pytest.warns(UserWarning, match=r".*Will pass in an empty Namespace instead."):
|
||||
LightningTestModelWithUnusedHyperparametersArg.load_from_checkpoint(last_checkpoint)
|
||||
CurrentModelUnusedHparams.load_from_checkpoint(last_checkpoint)
|
||||
|
|
Loading…
Reference in New Issue