lightning/tests/test_deprecated.py

51 lines
1.4 KiB
Python
Raw Normal View History

"""Test deprecated functionality which will be removed in vX.Y.Z"""
import sys
import torch
from tests.base import EvalModelTemplate
def _soft_unimport_module(str_module):
# once the module is imported e.g with parsing with pytest it lives in memory
if str_module in sys.modules:
del sys.modules[str_module]
class ModelVer0_6(EvalModelTemplate):
# todo: this shall not be needed while evaluate asks for dataloader explicitly
def val_dataloader(self):
return self.dataloader(train=False)
def validation_step(self, batch, batch_idx, *args, **kwargs):
return {'val_loss': torch.tensor(0.6)}
def validation_end(self, outputs):
return {'val_loss': torch.tensor(0.6)}
def test_dataloader(self):
return self.dataloader(train=False)
def test_end(self, outputs):
return {'test_loss': torch.tensor(0.6)}
class ModelVer0_7(EvalModelTemplate):
# todo: this shall not be needed while evaluate asks for dataloader explicitly
def val_dataloader(self):
return self.dataloader(train=False)
def validation_step(self, batch, batch_idx, *args, **kwargs):
return {'val_loss': torch.tensor(0.7)}
def validation_end(self, outputs):
return {'val_loss': torch.tensor(0.7)}
def test_dataloader(self):
return self.dataloader(train=False)
def test_end(self, outputs):
return {'test_loss': torch.tensor(0.7)}