# lightning/tests/test_deprecated.py
"""Test deprecated functionality which will be removed in vX.Y.Z"""
import random
import sys
import pytest
import torch
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
def _soft_unimport_module(str_module):
# once the module is imported e.g with parsing with pytest it lives in memory
if str_module in sys.modules:
del sys.modules[str_module]
def test_tbd_remove_in_v0_10_0_trainer():
    """Deprecated Trainer arguments warn and forward to their replacements.

    Each ``*_percent_check``/``overfit_pct`` constructor argument is a
    deprecated alias of a ``limit_*_batches``/``overfit_batches`` attribute:
    passing it must raise a DeprecationWarning yet still set the new
    attribute, and reading the old attribute back must warn again.
    """
    # (deprecated constructor argument, replacement attribute)
    deprecated_to_new = [
        ('overfit_pct', 'overfit_batches'),
        ('train_percent_check', 'limit_train_batches'),
        ('val_percent_check', 'limit_val_batches'),
        ('test_percent_check', 'limit_test_batches'),
    ]
    for old_name, new_name in deprecated_to_new:
        rnd_val = random.random()
        # NOTE: the match string was previously inconsistent for
        # `train_percent_check` (bare 'v0.10.0'); all aliases now assert the
        # full deprecation message.
        with pytest.deprecated_call(match='will be removed in v0.10.0'):
            trainer = Trainer(**{old_name: rnd_val})
        assert getattr(trainer, new_name) == rnd_val
        with pytest.deprecated_call(match='will be removed in v0.10.0'):
            assert getattr(trainer, old_name) == rnd_val

    trainer = Trainer()

    # `proc_rank` is a deprecated alias of `global_rank` (set and get).
    with pytest.deprecated_call(match='will be removed in v0.10.0'):
        trainer.proc_rank = 0
    with pytest.deprecated_call(match='will be removed in v0.10.0'):
        assert trainer.proc_rank == trainer.global_rank

    # `ckpt_path` is a deprecated alias of `weights_save_path`.
    with pytest.deprecated_call(match='will be removed in v0.10.0'):
        trainer.ckpt_path = 'foo'
        assert trainer.ckpt_path == trainer.weights_save_path == 'foo'
class ModelVer0_6(EvalModelTemplate):
    """Template model exercising the pre-v0.7 hook names.

    Uses the deprecated ``validation_end``/``test_end`` hooks and returns
    constant losses (0.6) so the resulting metrics are easy to assert on.
    """

    # todo: this shall not be needed while evaluate asks for dataloader explicitly
    def val_dataloader(self):
        return self.dataloader(train=False)

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        return dict(val_loss=torch.tensor(0.6))

    def validation_end(self, outputs):
        return dict(val_loss=torch.tensor(0.6))

    def test_dataloader(self):
        return self.dataloader(train=False)

    def test_end(self, outputs):
        return dict(test_loss=torch.tensor(0.6))
class ModelVer0_7(EvalModelTemplate):
    """Template model exercising the pre-v0.9 hook names.

    Uses the deprecated ``validation_end``/``test_end`` hooks and returns
    constant losses (0.7) so the resulting metrics are easy to assert on.
    """

    # todo: this shall not be needed while evaluate asks for dataloader explicitly
    def val_dataloader(self):
        return self.dataloader(train=False)

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        return dict(val_loss=torch.tensor(0.7))

    def validation_end(self, outputs):
        return dict(val_loss=torch.tensor(0.7))

    def test_dataloader(self):
        return self.dataloader(train=False)

    def test_end(self, outputs):
        return dict(test_loss=torch.tensor(0.7))
# def test_tbd_remove_in_v1_0_0_model_hooks():
#
# model = ModelVer0_6()
#
# with pytest.deprecated_call(match='will be removed in v1.0. Use `test_epoch_end` instead'):
# trainer = Trainer(logger=False)
# trainer.test(model)
# assert trainer.callback_metrics == {'test_loss': torch.tensor(0.6)}
#
# with pytest.deprecated_call(match='will be removed in v1.0. Use `validation_epoch_end` instead'):
# trainer = Trainer(logger=False)
# # TODO: why `dataloader` is required if it is not used
# result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
# assert result[0] == {'val_loss': torch.tensor(0.6)}
#
# model = ModelVer0_7()
#
# with pytest.deprecated_call(match='will be removed in v1.0. Use `test_epoch_end` instead'):
# trainer = Trainer(logger=False)
# trainer.test(model)
# assert trainer.callback_metrics == {'test_loss': torch.tensor(0.7)}
#
# with pytest.deprecated_call(match='will be removed in v1.0. Use `validation_epoch_end` instead'):
# trainer = Trainer(logger=False)
# # TODO: why `dataloader` is required if it is not used
# result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
# assert result[0] == {'val_loss': torch.tensor(0.7)}