from abc import ABC
from collections import OrderedDict

import torch


class TestStepVariations(ABC):
    """
    Houses all variations of test steps

    Mixin: the methods call ``self(x)`` and ``self.loss(y, y_hat)``, so the
    class they are mixed into must provide both.
    """

    # The ``Test`` prefix and ``test_*`` method names match pytest's default
    # collection rules; opt out explicitly so that importing this class into a
    # test module does not make pytest try to run these methods as tests.
    __test__ = False

    def _loss_and_accuracy(self, batch):
        """
        Run the forward pass on one batch and compute loss and accuracy

        :param batch: pair ``(x, y)`` of inputs and targets
        :return: tuple ``(loss_test, test_acc)``; ``test_acc`` is a tensor
            cast with ``type_as(x)``
        """
        x, y = batch
        # flatten everything after the first dimension
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        loss_test = self.loss(y, y_hat)

        # acc: fraction of entries whose argmax over dim 1 equals the target
        # (assumes y_hat is (samples, classes) -- TODO confirm)
        labels_hat = torch.argmax(y_hat, dim=1)
        test_acc = torch.sum(y == labels_hat).item() / len(y)
        test_acc = torch.tensor(test_acc).type_as(x)

        return loss_test, test_acc

    def test_step(self, batch, batch_idx, *args, **kwargs):
        """
        Default, baseline test_step

        :param batch: pair ``(x, y)`` of inputs and targets
        :param batch_idx: index of the batch, only used to pick the output
            variant
        :return: ``OrderedDict`` with keys ``'test_loss'`` and ``'test_acc'``
            (see the note in the body about the other variants)
        """
        loss_test, test_acc = self._loss_and_accuracy(batch)

        # alternate possible outputs to test
        # NOTE(review): ``batch_idx % 1 == 0`` holds for every integer
        # batch_idx, so the first branch always returns and the variants below
        # it are currently never produced. Left as-is because enabling them
        # would change the return type seen by callers -- confirm the intent.
        if batch_idx % 1 == 0:
            output = OrderedDict({
                'test_loss': loss_test,
                'test_acc': test_acc,
            })
            return output
        if batch_idx % 2 == 0:
            return test_acc

        if batch_idx % 3 == 0:
            output = OrderedDict({
                'test_loss': loss_test,
                'test_acc': test_acc,
                'test_dic': {'test_loss_a': loss_test}
            })
            return output

    def test_step_multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
        """
        Variation of test_step that additionally receives a dataloader index

        :param batch: pair ``(x, y)`` of inputs and targets
        :param batch_idx: index of the batch, only used to pick the output
            variant
        :param dataloader_idx: index used to suffix the keys of the last
            output variant
        :return: ``OrderedDict`` with keys ``'test_loss'`` and ``'test_acc'``
            (see the note in the body about the other variants)
        """
        loss_test, test_acc = self._loss_and_accuracy(batch)

        # alternate possible outputs to test
        # NOTE(review): ``batch_idx % 1 == 0`` holds for every integer
        # batch_idx, so the first branch always returns and the variants below
        # it (including the only use of ``dataloader_idx``) are currently never
        # produced. Left as-is because enabling them would change the return
        # type seen by callers -- confirm the intent.
        if batch_idx % 1 == 0:
            output = OrderedDict({
                'test_loss': loss_test,
                'test_acc': test_acc,
            })
            return output
        if batch_idx % 2 == 0:
            return test_acc

        if batch_idx % 3 == 0:
            output = OrderedDict({
                'test_loss': loss_test,
                'test_acc': test_acc,
                'test_dic': {'test_loss_a': loss_test}
            })
            return output
        if batch_idx % 5 == 0:
            output = OrderedDict({
                f'test_loss_{dataloader_idx}': loss_test,
                f'test_acc_{dataloader_idx}': test_acc,
            })
            return output