lightning/tests/test_deprecated.py

158 lines
5.7 KiB
Python
Raw Normal View History

"""Test deprecated functionality which will be removed in vX.Y.Z"""
import random
import sys
import pytest
import torch
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
def _soft_unimport_module(str_module):
    """Drop ``str_module`` from the ``sys.modules`` cache; do nothing if it is absent.

    A module that was already imported (e.g. while pytest parsed the test files)
    stays cached in memory. Removing its cache entry makes the next import of
    that name execute the module again.

    Args:
        str_module: fully qualified module name, e.g. ``"package.sub.module"``.
    """
    # `pop` with a default is silent when the key is missing
    sys.modules.pop(str_module, None)
def test_tbd_remove_in_v0_10_0_trainer():
    """Check the ``Trainer`` arguments and attributes flagged for removal in v0.10.0.

    Each deprecated keyword must raise a deprecation warning mentioning
    ``v0.10.0`` when passed to ``Trainer``, its value must show up on the
    replacement attribute, and reading it back under the old name must warn
    again while returning the same value.
    """
    # deprecated Trainer keyword -> attribute expected to carry its value
    renamed_args = (
        ('overfit_pct', 'overfit_batches'),
        ('train_percent_check', 'limit_train_batches'),
        ('val_percent_check', 'limit_val_batches'),
        ('test_percent_check', 'limit_test_batches'),
    )
    for old_name, new_name in renamed_args:
        fraction = random.random()
        with pytest.deprecated_call(match='v0.10.0'):
            trainer = Trainer(**{old_name: fraction})
        assert getattr(trainer, new_name) == fraction
        with pytest.deprecated_call(match='v0.10.0'):
            assert getattr(trainer, old_name) == fraction

    # `proc_rank` has to warn both when it is assigned and when it is read
    trainer = Trainer()
    with pytest.deprecated_call(match='v0.10.0'):
        trainer.proc_rank = 0
    with pytest.deprecated_call(match='v0.10.0'):
        assert trainer.proc_rank == trainer.global_rank
def test_tbd_remove_in_v0_9_0_trainer():
    """Check the ``Trainer`` arguments flagged for removal in v0.9.0.

    Passing ``show_progress_bar`` or ``num_tpu_cores`` must raise a deprecation
    warning mentioning ``v0.9.0``.
    """
    # `show_progress_bar` is expected to follow `progress_bar_refresh_rate`,
    # not the value passed for the deprecated flag itself:
    # (progress_bar_refresh_rate, show_progress_bar argument, expected flag)
    progress_bar_cases = (
        (0, True, False),
        (50, False, True),
    )
    for refresh_rate, requested, expected in progress_bar_cases:
        with pytest.deprecated_call(match='v0.9.0'):
            trainer = Trainer(progress_bar_refresh_rate=refresh_rate, show_progress_bar=requested)
        assert bool(trainer.show_progress_bar) is expected

    # the old TPU argument must end up on `tpu_cores`
    with pytest.deprecated_call(match='v0.9.0'):
        trainer = Trainer(num_tpu_cores=8)
    assert trainer.tpu_cores == 8
def test_tbd_remove_in_v0_9_0_module_imports():
    """Check that the legacy import paths still resolve and warn about removal in v0.9.0.

    Each module is dropped from ``sys.modules`` immediately before it is
    imported, so the import runs the module code again inside the
    ``deprecated_call`` context.
    """
    # The decorator is both imported and applied inside the context.
    _soft_unimport_module("pytorch_lightning.core.decorators")
    with pytest.deprecated_call(match='v0.9.0'):
        from pytorch_lightning.core.decorators import data_loader  # noqa: F811
        data_loader(print)

    # (legacy module path, logger class that must be importable from it)
    legacy_loggers = (
        ("pytorch_lightning.logging.comet", "CometLogger"),
        ("pytorch_lightning.logging.mlflow", "MLFlowLogger"),
        ("pytorch_lightning.logging.neptune", "NeptuneLogger"),
        ("pytorch_lightning.logging.test_tube", "TestTubeLogger"),
        ("pytorch_lightning.logging.wandb", "WandbLogger"),
    )
    for module_path, logger_name in legacy_loggers:
        _soft_unimport_module(module_path)
        with pytest.deprecated_call(match='v0.9.0'):
            # `__import__` with a `fromlist` is the call made by
            # `from <module_path> import <logger_name>`
            legacy_module = __import__(module_path, fromlist=[logger_name])
            getattr(legacy_module, logger_name)
class ModelVer0_6(EvalModelTemplate):
    """Model defining ``validation_end`` and ``test_end``; every reported metric is 0.6.

    Used by the v1.0 hook-deprecation test in this file.
    """

    # todo: this shall not be needed while evaluate asks for dataloader explicitly
    def val_dataloader(self):
        """Return ``self.dataloader(train=False)`` as the validation loader."""
        loader = self.dataloader(train=False)
        return loader

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        """Ignore the batch and report a constant ``val_loss``."""
        return dict(val_loss=torch.tensor(0.6))

    def validation_end(self, outputs):
        """Ignore ``outputs`` and report a constant ``val_loss``."""
        return dict(val_loss=torch.tensor(0.6))

    def test_dataloader(self):
        """Return ``self.dataloader(train=False)`` as the test loader."""
        loader = self.dataloader(train=False)
        return loader

    def test_end(self, outputs):
        """Ignore ``outputs`` and report a constant ``test_loss``."""
        return dict(test_loss=torch.tensor(0.6))
class ModelVer0_7(EvalModelTemplate):
    """Model defining ``validation_end`` and ``test_end``; every reported metric is 0.7.

    Used by the v1.0 hook-deprecation test in this file.
    """

    # todo: this shall not be needed while evaluate asks for dataloader explicitly
    def val_dataloader(self):
        """Return ``self.dataloader(train=False)`` as the validation loader."""
        loader = self.dataloader(train=False)
        return loader

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        """Ignore the batch and report a constant ``val_loss``."""
        return dict(val_loss=torch.tensor(0.7))

    def validation_end(self, outputs):
        """Ignore ``outputs`` and report a constant ``val_loss``."""
        return dict(val_loss=torch.tensor(0.7))

    def test_dataloader(self):
        """Return ``self.dataloader(train=False)`` as the test loader."""
        loader = self.dataloader(train=False)
        return loader

    def test_end(self, outputs):
        """Ignore ``outputs`` and report a constant ``test_loss``."""
        return dict(test_loss=torch.tensor(0.7))
def test_tbd_remove_in_v1_0_0_model_hooks():
    """Check that models defining ``validation_end`` / ``test_end`` still work but warn.

    For each model class, both ``Trainer.test`` and ``Trainer._evaluate`` must
    raise a deprecation warning mentioning ``v1.0`` and produce the constant
    metric value that the model returns.
    """
    hparams = EvalModelTemplate.get_default_hparams()

    # (model class, constant metric value it reports)
    model_cases = (
        (ModelVer0_6, 0.6),
        (ModelVer0_7, 0.7),
    )
    for model_cls, expected in model_cases:
        model = model_cls(hparams)

        # test loop
        with pytest.deprecated_call(match='v1.0'):
            trainer = Trainer(logger=False)
            trainer.test(model)
        assert trainer.callback_metrics == {'test_loss': torch.tensor(expected)}

        # validation loop, on a fresh trainer
        with pytest.deprecated_call(match='v1.0'):
            trainer = Trainer(logger=False)
            # TODO: why is `dataloaders` required if it is not used
            result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
        assert result == {'val_loss': torch.tensor(expected)}