test deprecated - model (#1074)

* pylint

* model API

* update test

* formatting

* disable logger

* fix checking overwrite

* fix test

* typo

* deprecated model

* fix for DDP

* drop Flake8 in GH actions

* Update pytorch_lightning/trainer/evaluation_loop.py

* fix imports

Co-authored-by: Nic Eggert <nic@eggert.io>
This commit is contained in:
Jirka Borovec 2020-03-20 20:51:14 +01:00 committed by GitHub
parent 792962ecc9
commit 3be81cb54e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 88 additions and 33 deletions

View File

@ -65,13 +65,14 @@ jobs:
pip --version
pip list
- name: Lint and Tests
- name: Tests
# env:
# TOXENV: py${{ matrix.python-version }}
run: |
# tox --sitepackages
flake8 .
# flake8 .
coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml
coverage report
- name: Upload pytest test results
uses: actions/upload-artifact@master

View File

@ -7,7 +7,3 @@ import warnings
warnings.warn("`root_module` package has been renamed to `core` since v0.6.0."
" The deprecated package name will be removed in v0.8.0.", DeprecationWarning)
from pytorch_lightning.core import ( # noqa: E402
decorators, grads, hooks, root_module, memory, model_saving
)

View File

@ -124,15 +124,16 @@ In this second case, the options you pass to trainer will be used when running
"""
import sys
import warnings
from abc import ABC, abstractmethod
from typing import Callable
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import warnings
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.debugging import MisconfigurationException
try:
@ -215,7 +216,7 @@ class TrainerEvaluationLoopMixin(ABC):
def reset_val_dataloader(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""
def evaluate(self, model, dataloaders, max_batches, test_mode: bool = False):
def evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
"""Run evaluation code.
Args:
@ -291,24 +292,23 @@ class TrainerEvaluationLoopMixin(ABC):
outputs = outputs[0]
# give model a chance to do something with the outputs (and method defined)
model = self.get_model()
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
model = model.module
if test_mode and self.is_overriden('test_epoch_end'):
eval_results = model.test_epoch_end(outputs)
elif self.is_overriden('validation_epoch_end'):
eval_results = model.validation_epoch_end(outputs)
# TODO: remove in v 1.0.0
if test_mode and self.is_overriden('test_end'):
# TODO: remove in v1.0.0
if test_mode and self.is_overriden('test_end', model=model):
eval_results = model.test_end(outputs)
m = 'test_end was deprecated in 0.7.0 and will be removed 1.0.0. ' \
'Use test_epoch_end instead.'
warnings.warn(m, DeprecationWarning)
elif self.is_overriden('validation_end'):
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `test_epoch_end` instead.', DeprecationWarning)
elif self.is_overriden('validation_end', model=model):
eval_results = model.validation_end(outputs)
m = 'validation_end was deprecated in 0.7.0 and will be removed 1.0.0. ' \
'Use validation_epoch_end instead.'
warnings.warn(m, DeprecationWarning)
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `validation_epoch_end` instead.', DeprecationWarning)
if test_mode and self.is_overriden('test_epoch_end', model=model):
eval_results = model.test_epoch_end(outputs)
elif self.is_overriden('validation_epoch_end', model=model):
eval_results = model.validation_epoch_end(outputs)
# enable train mode again
model.train()

View File

@ -11,13 +11,17 @@ class TrainerModelHooksMixin(ABC):
f_op = getattr(model, f_name, None)
return callable(f_op)
def is_overriden(self, f_name, model=None):
def is_overriden(self, method_name: str, model: LightningModule = None) -> bool:
if model is None:
model = self.get_model()
super_object = LightningModule
if not hasattr(model, method_name):
# in case of calling deprecated method
return False
# when code pointers are different, it was overriden
is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__
is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
return is_overriden
def has_arg(self, f_name, arg_name):

View File

@ -282,7 +282,7 @@ class TrainerTrainLoopMixin(ABC):
def train(self):
warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
' but will start from "0" in v0.8.0.', RuntimeWarning)
# get model
model = self.get_model()

View File

@ -138,13 +138,13 @@ def test_adding_step_key(tmpdir):
return decorated
model, hparams = tutils.get_model()
model.validation_end = _validation_end
model.validation_epoch_end = _validation_end
trainer_options = dict(
max_epochs=4,
default_save_path=tmpdir,
train_percent_check=0.001,
val_percent_check=0.01,
num_sanity_val_steps=0
num_sanity_val_steps=0,
)
trainer = Trainer(**trainer_options)
trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics)

View File

@ -225,8 +225,7 @@ class TestModelBase(LightningModule):
parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)
# training params (opt)
parser.opt_list('--learning_rate', default=0.001 * 8, type=float,
options=[0.0001, 0.0005, 0.001, 0.005],
tunable=False)
options=[0.0001, 0.0005, 0.001, 0.005], tunable=False)
parser.opt_list('--optimizer_name', default='adam', type=str,
options=['adam'], tunable=False)
# if using 2 nodes with 4 gpus each the batch size here

View File

@ -2,8 +2,11 @@
from pytorch_lightning import Trainer
import tests.models.utils as tutils
from tests.models import TestModelBase, LightTrainDataloader, LightEmptyTestStep
def test_to_be_removed_in_v0_8_0_module_imports():
def test_tbd_remove_in_v0_8_0_module_imports():
from pytorch_lightning.logging.comet_logger import CometLogger # noqa: F811
from pytorch_lightning.logging.mlflow_logger import MLFlowLogger # noqa: F811
from pytorch_lightning.logging.test_tube_logger import TestTubeLogger # noqa: F811
@ -24,7 +27,7 @@ def test_to_be_removed_in_v0_8_0_module_imports():
from pytorch_lightning.root_module.root_module import LightningModule # noqa: F811
def test_to_be_removed_in_v0_8_0_trainer():
def test_tbd_remove_in_v0_8_0_trainer():
mapping_old_new = {
'gradient_clip': 'gradient_clip_val',
'nb_gpu_nodes': 'num_nodes',
@ -45,7 +48,7 @@ def test_to_be_removed_in_v0_8_0_trainer():
'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new)
def test_to_be_removed_in_v0_9_0_module_imports():
def test_tbd_remove_in_v0_9_0_module_imports():
from pytorch_lightning.core.decorators import data_loader # noqa: F811
from pytorch_lightning.logging.comet import CometLogger # noqa: F402
@ -53,3 +56,55 @@ def test_to_be_removed_in_v0_9_0_module_imports():
from pytorch_lightning.logging.neptune import NeptuneLogger # noqa: F402
from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402
from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402
class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
# todo: this shall not be needed while evaluate asks for dataloader explicitly
def val_dataloader(self):
return self._dataloader(train=False)
def validation_end(self, outputs):
return {'val_loss': 0.6}
def test_end(self, outputs):
return {'test_loss': 0.6}
class ModelVer0_7(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
# todo: this shall not be needed while evaluate asks for dataloader explicitly
def val_dataloader(self):
return self._dataloader(train=False)
def validation_end(self, outputs):
return {'val_loss': 0.7}
def test_end(self, outputs):
return {'test_loss': 0.7}
def test_tbd_remove_in_v1_0_0_model_hooks():
hparams = tutils.get_hparams()
model = ModelVer0_6(hparams)
trainer = Trainer(logger=False)
trainer.test(model)
assert trainer.callback_metrics == {'test_loss': 0.6}
trainer = Trainer(logger=False)
# TODO: why `dataloder` is required if it is not used
result = trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
assert result == {'val_loss': 0.6}
model = ModelVer0_7(hparams)
trainer = Trainer(logger=False)
trainer.test(model)
assert trainer.callback_metrics == {'test_loss': 0.7}
trainer = Trainer(logger=False)
# TODO: why `dataloder` is required if it is not used
result = trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
assert result == {'val_loss': 0.7}