test deprecated - model (#1074)

* pylint
* model API
* update test
* formatting
* disable logger
* fix checking overwrite
* fix test
* typo
* deprecated model
* fix for DDP
* drop Flake8 in GH actions
* Update pytorch_lightning/trainer/evaluation_loop.py
* fix imports

Co-authored-by: Nic Eggert <nic@eggert.io>
Parent: 792962ecc9
Commit: 3be81cb54e
GitHub Actions CI workflow (exact path not preserved in this capture):

```diff
@@ -65,13 +65,14 @@ jobs:
           pip --version
           pip list
 
-      - name: Lint and Tests
+      - name: Tests
         # env:
         #   TOXENV: py${{ matrix.python-version }}
         run: |
           # tox --sitepackages
-          flake8 .
+          # flake8 .
           coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml
           coverage report
 
       - name: Upload pytest test results
         uses: actions/upload-artifact@master
```
pytorch_lightning/root_module/__init__.py (path inferred from the imports):

```diff
@@ -7,7 +7,3 @@ import warnings
-
-warnings.warn("`root_module` package has been renamed to `core` since v0.6.0."
-              " The deprecated package name will be removed in v0.8.0.", DeprecationWarning)
-
 from pytorch_lightning.core import (  # noqa: E402
     decorators, grads, hooks, root_module, memory, model_saving
 )
```
pytorch_lightning/trainer/evaluation_loop.py:

```diff
@@ -124,15 +124,16 @@ In this second case, the options you pass to trainer will be used when running
 """
 
 import sys
+import warnings
 from abc import ABC, abstractmethod
 from typing import Callable
 
 import torch
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-import warnings
 
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
 try:
```
```diff
@@ -215,7 +216,7 @@ class TrainerEvaluationLoopMixin(ABC):
     def reset_val_dataloader(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def evaluate(self, model, dataloaders, max_batches, test_mode: bool = False):
+    def evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
         """Run evaluation code.
 
         Args:
```
```diff
@@ -291,24 +292,23 @@ class TrainerEvaluationLoopMixin(ABC):
             outputs = outputs[0]
 
         # give model a chance to do something with the outputs (and method defined)
         model = self.get_model()
         if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
             model = model.module
 
-        if test_mode and self.is_overriden('test_epoch_end'):
-            eval_results = model.test_epoch_end(outputs)
-        elif self.is_overriden('validation_epoch_end'):
-            eval_results = model.validation_epoch_end(outputs)
-        # TODO: remove in v 1.0.0
-        if test_mode and self.is_overriden('test_end'):
+        # TODO: remove in v1.0.0
+        if test_mode and self.is_overriden('test_end', model=model):
             eval_results = model.test_end(outputs)
-            m = 'test_end was deprecated in 0.7.0 and will be removed 1.0.0. ' \
-                'Use test_epoch_end instead.'
-            warnings.warn(m, DeprecationWarning)
-        elif self.is_overriden('validation_end'):
+            warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
+                          ' Use `test_epoch_end` instead.', DeprecationWarning)
+        elif self.is_overriden('validation_end', model=model):
             eval_results = model.validation_end(outputs)
-            m = 'validation_end was deprecated in 0.7.0 and will be removed 1.0.0. ' \
-                'Use validation_epoch_end instead.'
-            warnings.warn(m, DeprecationWarning)
+            warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
+                          ' Use `validation_epoch_end` instead.', DeprecationWarning)
+
+        if test_mode and self.is_overriden('test_epoch_end', model=model):
+            eval_results = model.test_epoch_end(outputs)
+        elif self.is_overriden('validation_epoch_end', model=model):
+            eval_results = model.validation_epoch_end(outputs)
 
         # enable train mode again
         model.train()
```
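The reordered block keeps the deprecated `test_end`/`validation_end` hooks working while warning users toward the `*_epoch_end` replacements. A self-contained sketch of the same warn-and-fallback pattern (illustrative names, outside the Trainer, using a plain `hasattr` check for brevity):

```python
import warnings

class OldStyleModel:
    """Stands in for a LightningModule that still overrides the old hook."""
    def validation_end(self, outputs):
        return {'val_loss': 0.6}

def run_validation_epoch_end(model, outputs):
    # Prefer the deprecated hook when present, but warn; else call the new one.
    if hasattr(model, 'validation_end'):
        warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
                      ' Use `validation_epoch_end` instead.', DeprecationWarning)
        return model.validation_end(outputs)
    return model.validation_epoch_end(outputs)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    result = run_validation_epoch_end(OldStyleModel(), outputs=[])

assert result == {'val_loss': 0.6}
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```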
pytorch_lightning/trainer/model_hooks.py (path inferred from TrainerModelHooksMixin):

```diff
@@ -11,13 +11,17 @@ class TrainerModelHooksMixin(ABC):
         f_op = getattr(model, f_name, None)
         return callable(f_op)
 
-    def is_overriden(self, f_name, model=None):
+    def is_overriden(self, method_name: str, model: LightningModule = None) -> bool:
         if model is None:
             model = self.get_model()
         super_object = LightningModule
 
+        if not hasattr(model, method_name):
+            # in case of calling deprecated method
+            return False
+
         # when code pointers are different, it was overriden
-        is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__
+        is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
         return is_overriden
 
     def has_arg(self, f_name, arg_name):
```
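The `__code__` comparison is what makes the hook detection work: a method inherited from the base class shares the base's code object, while an override carries its own. The new `model=` parameter also lets `evaluate` pass the DDP-unwrapped module explicitly instead of relying on `self.get_model()`. A minimal standalone demonstration (hypothetical classes):

```python
class Base:
    def validation_epoch_end(self, outputs):
        pass

class Inherits(Base):
    pass  # does not override the hook

class Overrides(Base):
    def validation_epoch_end(self, outputs):
        return {'val_loss': 0.1}

def is_overriden(model, method_name: str, base=Base) -> bool:
    if not hasattr(model, method_name):
        return False  # e.g. probing a hook that was never defined
    # Different code objects => the subclass supplied its own implementation.
    return getattr(model, method_name).__code__ is not getattr(base, method_name).__code__

assert is_overriden(Overrides(), 'validation_epoch_end')
assert not is_overriden(Inherits(), 'validation_epoch_end')
assert not is_overriden(Inherits(), 'no_such_hook')
```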
pytorch_lightning/trainer/training_loop.py (path inferred from TrainerTrainLoopMixin):

```diff
@@ -282,7 +282,7 @@ class TrainerTrainLoopMixin(ABC):
 
     def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
-                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
+                      ' but will start from "0" in v0.8.0.', RuntimeWarning)
 
         # get model
         model = self.get_model()
```
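The switch from `DeprecationWarning` to `RuntimeWarning` is presumably about visibility: Python's default filters hide `DeprecationWarning` raised outside `__main__`, while `RuntimeWarning` reaches end users, and this message targets users rather than library developers. A quick check of that filtering behaviour:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    # Mimic the default filter set: DeprecationWarning from library code
    # is suppressed, RuntimeWarning is always shown.
    warnings.simplefilter('ignore', DeprecationWarning)
    warnings.simplefilter('always', RuntimeWarning)
    warnings.warn('hidden from users', DeprecationWarning)
    warnings.warn('epoch numbers will start from "0" in v0.8.0', RuntimeWarning)

assert [w.category for w in caught] == [RuntimeWarning]
```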
tests for loggers (test_adding_step_key; exact path not preserved in this capture):

```diff
@@ -138,13 +138,13 @@ def test_adding_step_key(tmpdir):
         return decorated
 
     model, hparams = tutils.get_model()
-    model.validation_end = _validation_end
+    model.validation_epoch_end = _validation_end
     trainer_options = dict(
         max_epochs=4,
         default_save_path=tmpdir,
         train_percent_check=0.001,
         val_percent_check=0.01,
-        num_sanity_val_steps=0
+        num_sanity_val_steps=0,
     )
     trainer = Trainer(**trainer_options)
     trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics)
```
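Assigning `model.validation_epoch_end = _validation_end` patches the hook on the instance; the patched function is never bound, so it takes no `self`, and the `__code__` check above still counts it as overridden. A small illustration (hypothetical `Base` class):

```python
class Base:
    def validation_epoch_end(self, outputs):
        return None

def _validation_end(outputs):  # no `self`: assigned on the instance, so never bound
    return {'val_loss': 0.5}

model = Base()
model.validation_epoch_end = _validation_end

# The __code__ comparison used by is_overriden treats the patch as an override:
assert model.validation_epoch_end.__code__ is not Base.validation_epoch_end.__code__
assert model.validation_epoch_end([]) == {'val_loss': 0.5}
```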
tests shared model base (TestModelBase; exact path not preserved in this capture):

```diff
@@ -225,8 +225,7 @@ class TestModelBase(LightningModule):
         parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)
         # training params (opt)
         parser.opt_list('--learning_rate', default=0.001 * 8, type=float,
-                        options=[0.0001, 0.0005, 0.001, 0.005],
-                        tunable=False)
+                        options=[0.0001, 0.0005, 0.001, 0.005], tunable=False)
         parser.opt_list('--optimizer_name', default='adam', type=str,
                         options=['adam'], tunable=False)
         # if using 2 nodes with 4 gpus each the batch size here
```
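`opt_list` comes from test-tube's `HyperOptArgumentParser`, an argparse extension that registers an argument together with the values a hyperparameter search may sample; `tunable=False` pins it to the default. A minimal usage sketch, assuming test-tube is installed:

```python
from test_tube import HyperOptArgumentParser

parser = HyperOptArgumentParser(strategy='random_search')
parser.opt_list('--learning_rate', default=0.001 * 8, type=float,
                options=[0.0001, 0.0005, 0.001, 0.005], tunable=False)
hparams = parser.parse_args([])
print(hparams.learning_rate)  # -> 0.008 (the default; options are only sampled in tuning runs)
```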
tests/test_deprecated.py (path inferred from content):

```diff
@@ -2,8 +2,11 @@
 
 from pytorch_lightning import Trainer
 
+import tests.models.utils as tutils
+from tests.models import TestModelBase, LightTrainDataloader, LightEmptyTestStep
+
 
-def test_to_be_removed_in_v0_8_0_module_imports():
+def test_tbd_remove_in_v0_8_0_module_imports():
     from pytorch_lightning.logging.comet_logger import CometLogger  # noqa: F811
     from pytorch_lightning.logging.mlflow_logger import MLFlowLogger  # noqa: F811
     from pytorch_lightning.logging.test_tube_logger import TestTubeLogger  # noqa: F811
```
```diff
@@ -24,7 +27,7 @@ def test_to_be_removed_in_v0_8_0_module_imports():
     from pytorch_lightning.root_module.root_module import LightningModule  # noqa: F811
 
 
-def test_to_be_removed_in_v0_8_0_trainer():
+def test_tbd_remove_in_v0_8_0_trainer():
     mapping_old_new = {
         'gradient_clip': 'gradient_clip_val',
         'nb_gpu_nodes': 'num_nodes',
```
```diff
@@ -45,7 +48,7 @@ def test_to_be_removed_in_v0_8_0_trainer():
         'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new)
 
 
-def test_to_be_removed_in_v0_9_0_module_imports():
+def test_tbd_remove_in_v0_9_0_module_imports():
     from pytorch_lightning.core.decorators import data_loader  # noqa: F811
 
     from pytorch_lightning.logging.comet import CometLogger  # noqa: F402
```
```diff
@@ -53,3 +56,55 @@ def test_to_be_removed_in_v0_9_0_module_imports():
     from pytorch_lightning.logging.neptune import NeptuneLogger  # noqa: F402
     from pytorch_lightning.logging.test_tube import TestTubeLogger  # noqa: F402
     from pytorch_lightning.logging.wandb import WandbLogger  # noqa: F402
+
+
+class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
+
+    # todo: this shall not be needed while evaluate asks for dataloader explicitly
+    def val_dataloader(self):
+        return self._dataloader(train=False)
+
+    def validation_end(self, outputs):
+        return {'val_loss': 0.6}
+
+    def test_end(self, outputs):
+        return {'test_loss': 0.6}
+
+
+class ModelVer0_7(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
+
+    # todo: this shall not be needed while evaluate asks for dataloader explicitly
+    def val_dataloader(self):
+        return self._dataloader(train=False)
+
+    def validation_end(self, outputs):
+        return {'val_loss': 0.7}
+
+    def test_end(self, outputs):
+        return {'test_loss': 0.7}
+
+
+def test_tbd_remove_in_v1_0_0_model_hooks():
+    hparams = tutils.get_hparams()
+
+    model = ModelVer0_6(hparams)
+
+    trainer = Trainer(logger=False)
+    trainer.test(model)
+    assert trainer.callback_metrics == {'test_loss': 0.6}
+
+    trainer = Trainer(logger=False)
+    # TODO: why `dataloder` is required if it is not used
+    result = trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
+    assert result == {'val_loss': 0.6}
+
+    model = ModelVer0_7(hparams)
+
+    trainer = Trainer(logger=False)
+    trainer.test(model)
+    assert trainer.callback_metrics == {'test_loss': 0.7}
+
+    trainer = Trainer(logger=False)
+    # TODO: why `dataloder` is required if it is not used
+    result = trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
+    assert result == {'val_loss': 0.7}
```
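These tests pin the behaviour of the deprecated hooks through the metrics they return. A complementary check (not part of this commit) could assert the warning itself with pytest's `deprecated_call` helper:

```python
import pytest

def test_validation_end_warns():
    model = ModelVer0_6(tutils.get_hparams())
    trainer = Trainer(logger=False)
    with pytest.deprecated_call():
        trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
```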