split tests for deprecated api (#5071)
* imports * imports * flake8 Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
3100b7839a
commit
b50ad9ee95
|
@ -0,0 +1,21 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test deprecated functionality which will be removed in vX.Y.Z"""
|
||||
import sys
|
||||
|
||||
|
||||
def _soft_unimport_module(str_module):
|
||||
# once the module is imported e.g with parsing with pytest it lives in memory
|
||||
if str_module in sys.modules:
|
||||
del sys.modules[str_module]
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test deprecated functionality which will be removed in vX.Y.Z"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
def test_tbd_remove_in_v1_2_0():
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
ModelCheckpoint(filepath='..')
|
||||
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
ModelCheckpoint('..')
|
||||
|
||||
with pytest.raises(MisconfigurationException, match='inputs which are not feasible'):
|
||||
ModelCheckpoint(filepath='..', dirpath='.')
|
||||
|
||||
|
||||
def test_tbd_remove_in_v1_2_0_metrics():
|
||||
from pytorch_lightning.metrics.classification import Fbeta
|
||||
from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score
|
||||
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
Fbeta(2)
|
||||
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2)
|
||||
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0]))
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test deprecated functionality which will be removed in vX.Y.Z"""
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from unittest import mock
|
||||
|
||||
|
@ -21,10 +20,8 @@ import torch
|
|||
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.metrics.functional.classification import auc
|
||||
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
def test_tbd_remove_in_v1_3_0(tmpdir):
|
||||
|
@ -52,27 +49,27 @@ def test_tbd_remove_in_v1_3_0(tmpdir):
|
|||
|
||||
|
||||
def test_tbd_remove_in_v1_3_0_metrics():
|
||||
from pytorch_lightning.metrics.functional.classification import to_onehot
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import to_onehot
|
||||
to_onehot(torch.tensor([1, 2, 3]))
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import to_categorical
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import to_categorical
|
||||
to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]]))
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import get_num_classes
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import get_num_classes
|
||||
get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1]))
|
||||
|
||||
x_binary = torch.tensor([0, 1, 2, 3])
|
||||
y_binary = torch.tensor([0, 1, 2, 3])
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import roc
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import roc
|
||||
roc(pred=x_binary, target=y_binary)
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import _roc
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import _roc
|
||||
_roc(pred=x_binary, target=y_binary)
|
||||
|
||||
x_multy = torch.tensor([[0.85, 0.05, 0.05, 0.05],
|
||||
|
@ -81,64 +78,40 @@ def test_tbd_remove_in_v1_3_0_metrics():
|
|||
[0.05, 0.05, 0.05, 0.85]])
|
||||
y_multy = torch.tensor([0, 1, 3, 2])
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import multiclass_roc
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import multiclass_roc
|
||||
multiclass_roc(pred=x_multy, target=y_multy)
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import average_precision
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import average_precision
|
||||
average_precision(pred=x_binary, target=y_binary)
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import precision_recall_curve
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import precision_recall_curve
|
||||
precision_recall_curve(pred=x_binary, target=y_binary)
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve
|
||||
multiclass_precision_recall_curve(pred=x_multy, target=y_multy)
|
||||
|
||||
from pytorch_lightning.metrics.functional.reduction import reduce
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.reduction import reduce
|
||||
reduce(torch.tensor([0, 1, 1, 0]), 'sum')
|
||||
|
||||
from pytorch_lightning.metrics.functional.reduction import class_reduce
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
from pytorch_lightning.metrics.functional.reduction import class_reduce
|
||||
class_reduce(torch.randint(1, 10, (50,)).float(),
|
||||
torch.randint(10, 20, (50,)).float(),
|
||||
torch.randint(1, 100, (50,)).float())
|
||||
|
||||
|
||||
def test_tbd_remove_in_v1_2_0():
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
checkpoint_cb = ModelCheckpoint(filepath='.')
|
||||
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
checkpoint_cb = ModelCheckpoint('.')
|
||||
|
||||
with pytest.raises(MisconfigurationException, match='inputs which are not feasible'):
|
||||
checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.')
|
||||
|
||||
|
||||
def test_tbd_remove_in_v1_2_0_metrics():
|
||||
from pytorch_lightning.metrics.classification import Fbeta
|
||||
from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score
|
||||
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
Fbeta(2)
|
||||
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2)
|
||||
|
||||
with pytest.deprecated_call(match='will be removed in v1.2'):
|
||||
f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0]))
|
||||
|
||||
|
||||
# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
|
||||
@pytest.mark.parametrize(['profiler', 'expected'], [
|
||||
(True, SimpleProfiler),
|
||||
(False, PassThroughProfiler),
|
||||
])
|
||||
def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
|
||||
# remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
|
||||
with pytest.deprecated_call(match='will be removed in v1.3'):
|
||||
trainer = Trainer(profiler=profiler)
|
||||
assert isinstance(trainer.profiler, expected)
|
||||
|
@ -162,47 +135,3 @@ def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, ex
|
|||
assert getattr(args, "profiler") == expected_parsed_arg
|
||||
trainer = Trainer.from_argparse_args(args)
|
||||
assert isinstance(trainer.profiler, expected_profiler)
|
||||
|
||||
|
||||
def _soft_unimport_module(str_module):
|
||||
# once the module is imported e.g with parsing with pytest it lives in memory
|
||||
if str_module in sys.modules:
|
||||
del sys.modules[str_module]
|
||||
|
||||
|
||||
class ModelVer0_6(EvalModelTemplate):
|
||||
|
||||
# todo: this shall not be needed while evaluate asks for dataloader explicitly
|
||||
def val_dataloader(self):
|
||||
return self.dataloader(train=False)
|
||||
|
||||
def validation_step(self, batch, batch_idx, *args, **kwargs):
|
||||
return {'val_loss': torch.tensor(0.6)}
|
||||
|
||||
def validation_end(self, outputs):
|
||||
return {'val_loss': torch.tensor(0.6)}
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.dataloader(train=False)
|
||||
|
||||
def test_end(self, outputs):
|
||||
return {'test_loss': torch.tensor(0.6)}
|
||||
|
||||
|
||||
class ModelVer0_7(EvalModelTemplate):
|
||||
|
||||
# todo: this shall not be needed while evaluate asks for dataloader explicitly
|
||||
def val_dataloader(self):
|
||||
return self.dataloader(train=False)
|
||||
|
||||
def validation_step(self, batch, batch_idx, *args, **kwargs):
|
||||
return {'val_loss': torch.tensor(0.7)}
|
||||
|
||||
def validation_end(self, outputs):
|
||||
return {'val_loss': torch.tensor(0.7)}
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.dataloader(train=False)
|
||||
|
||||
def test_end(self, outputs):
|
||||
return {'test_loss': torch.tensor(0.7)}
|
Loading…
Reference in New Issue