# lightning/tests/base/model_test_steps.py
# (header scraped from a repository web view: 205 lines, 6.1 KiB, Python)

import random
from abc import ABC
from collections import OrderedDict
import torch
from pytorch_lightning import EvalResult
class TestStepVariations(ABC):
"""
Houses all variations of test steps
"""
def test_step(self, batch, batch_idx, *args, **kwargs):
"""
Default, baseline test_step
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_test = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
test_acc = torch.tensor(test_acc)
test_acc = test_acc.type_as(x)
# alternate possible outputs to test
if batch_idx % 1 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
})
return output
if batch_idx % 2 == 0:
return test_acc
if batch_idx % 3 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {'test_loss_a': loss_test}
})
return output
def test_step_result_obj(self, batch, batch_idx, *args, **kwargs):
"""
Default, baseline test_step
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_test = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
test_acc = torch.tensor(test_acc)
test_acc = test_acc.type_as(x)
result = EvalResult()
# alternate possible outputs to test
if batch_idx % 1 == 0:
result.log_dict({
'test_loss': loss_test,
'test_acc': test_acc,
})
return result
if batch_idx % 2 == 0:
return test_acc
if batch_idx % 3 == 0:
result.log_dict({
'test_loss': loss_test,
'test_acc': test_acc,
})
result.test_dic = {'test_loss_a': loss_test}
return result
def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
"""
Default, baseline test_step
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_test = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
test_acc = torch.tensor(test_acc)
test_acc = test_acc.type_as(x)
# alternate possible outputs to test
if batch_idx % 1 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
})
return output
if batch_idx % 2 == 0:
return test_acc
if batch_idx % 3 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {'test_loss_a': loss_test}
})
return output
if batch_idx % 5 == 0:
output = OrderedDict({
f'test_loss_{dataloader_idx}': loss_test,
f'test_acc_{dataloader_idx}': test_acc,
})
return output
def test_step__empty(self, batch, batch_idx, *args, **kwargs):
return {}
def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_test = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
test_acc = torch.tensor(test_acc)
test_acc = test_acc.type_as(x)
# Do regular EvalResult Logging
result = EvalResult(checkpoint_on=loss_test)
result.log('test_loss', loss_test)
result.log('test_acc', test_acc)
batch_size = x.size(0)
lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
lst_of_int = [random.randint(500, 1000) for i in range(batch_size)]
lst_of_lst = [[x] for x in lst_of_int]
lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)]
# This is passed in from pytest via parameterization
option = getattr(self, 'test_option', 0)
prediction_file = getattr(self, 'prediction_file', 'predictions.pt')
lazy_ids = torch.arange(batch_idx * self.batch_size, batch_idx * self.batch_size + x.size(0))
# Base
if option == 0:
result.write('idxs', lazy_ids, prediction_file)
result.write('preds', labels_hat, prediction_file)
# Check mismatching tensor len
elif option == 1:
result.write('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
result.write('preds', labels_hat, prediction_file)
# write multi-dimension
elif option == 2:
result.write('idxs', lazy_ids, prediction_file)
result.write('preds', labels_hat, prediction_file)
result.write('x', x, prediction_file)
# write str list
elif option == 3:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_str, prediction_file)
# write int list
elif option == 4:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_int, prediction_file)
# write nested list
elif option == 5:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_lst, prediction_file)
# write dict list
elif option == 6:
result.write('idxs', lazy_ids, prediction_file)
result.write('vals', lst_of_dict, prediction_file)
return result