Prune EvalModelTemplate (#11153)

Rohit Gupta 2021-12-19 18:38:43 +05:30 committed by GitHub
parent f95976d602
commit 61eb6230c2
15 changed files with 9 additions and 793 deletions
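For context (not part of the diff): the generic test helpers edited below now exercise BoringModel instead of the removed EvalModelTemplate. A minimal sketch of that replacement pattern, assuming the tests/helpers layout referenced in the imports further down; the Trainer flags are illustrative, not taken from this commit:

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # replacement for the pruned EvalModelTemplate

model = BoringModel()
trainer = Trainer(fast_dev_run=True)  # illustrative only
trainer.fit(model)
trainer.test(model)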

View File

@@ -1,3 +0,0 @@
"""Models for testing."""
from tests.base.model_template import EvalModelTemplate, GenericEvalModelTemplate # noqa: F401

View File

@@ -1,57 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from torch import optim
class ConfigureOptimizersPool(ABC):
def configure_optimizers(self):
"""return whatever optimizers we want here.
:return: list of optimizers
"""
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer
def configure_optimizers__adagrad(self):
optimizer = optim.Adagrad(self.parameters(), lr=self.learning_rate)
return optimizer
def configure_optimizers__single_scheduler(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
return [optimizer], [lr_scheduler]
def configure_optimizers__multiple_schedulers(self):
optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]
def configure_optimizers__param_groups(self):
param_groups = [
{"params": list(self.parameters())[:2], "lr": self.learning_rate * 0.1},
{"params": list(self.parameters())[2:], "lr": self.learning_rate},
]
optimizer = optim.Adam(param_groups)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
return [optimizer], [lr_scheduler]
def configure_optimizers__lr_from_hparams(self):
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return optimizer

View File

@@ -1,178 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Generic, TypeVar
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.core.lightning import LightningModule
from tests import _PATH_DATASETS
from tests.base.model_optimizers import ConfigureOptimizersPool
from tests.base.model_test_dataloaders import TestDataloaderVariations
from tests.base.model_test_epoch_ends import TestEpochEndVariations
from tests.base.model_test_steps import TestStepVariations
from tests.base.model_train_dataloaders import TrainDataloaderVariations
from tests.base.model_train_steps import TrainingStepVariations
from tests.base.model_utilities import ModelTemplateData, ModelTemplateUtils
from tests.base.model_valid_dataloaders import ValDataloaderVariations
from tests.base.model_valid_epoch_ends import ValidationEpochEndVariations
from tests.base.model_valid_steps import ValidationStepVariations
from tests.helpers.datasets import TrialMNIST
class EvalModelTemplate(
ModelTemplateData,
ModelTemplateUtils,
TrainingStepVariations,
ValidationStepVariations,
ValidationEpochEndVariations,
TestStepVariations,
TestEpochEndVariations,
TrainDataloaderVariations,
ValDataloaderVariations,
TestDataloaderVariations,
ConfigureOptimizersPool,
LightningModule,
):
"""This template houses all combinations of model configurations we want to test.
>>> model = EvalModelTemplate()
"""
def __init__(
self,
drop_prob: float = 0.2,
batch_size: int = 32,
in_features: int = 28 * 28,
learning_rate: float = 0.001 * 8,
optimizer_name: str = "adam",
data_root: str = _PATH_DATASETS,
out_features: int = 10,
hidden_dim: int = 1000,
b1: float = 0.5,
b2: float = 0.999,
):
# init superclass
super().__init__()
self.save_hyperparameters()
self.drop_prob = drop_prob
self.batch_size = batch_size
self.in_features = in_features
self.learning_rate = learning_rate
self.optimizer_name = optimizer_name
self.data_root = data_root
self.out_features = out_features
self.hidden_dim = hidden_dim
self.b1 = b1
self.b2 = b2
self.training_step_called = False
self.training_step_end_called = False
self.training_epoch_end_called = False
self.validation_step_called = False
self.validation_step_end_called = False
self.validation_epoch_end_called = False
self.test_step_called = False
self.test_step_end_called = False
self.test_epoch_end_called = False
self.example_input_array = torch.rand(5, 28 * 28)
# build model
self.__build_model()
def __build_model(self):
"""
Simple model for testing
:return:
"""
self.c_d1 = nn.Linear(in_features=self.in_features, out_features=self.hidden_dim)
self.c_d1_bn = nn.BatchNorm1d(self.hidden_dim)
self.c_d1_drop = nn.Dropout(self.drop_prob)
self.c_d2 = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
def forward(self, x):
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
x = self.c_d1_drop(x)
x = self.c_d2(x)
logits = F.softmax(x, dim=1)
return logits
def loss(self, labels, logits):
nll = F.nll_loss(logits, labels)
return nll
def prepare_data(self):
TrialMNIST(root=self.data_root, train=True, download=True)
@staticmethod
def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0) -> dict:
args = dict(
drop_prob=0.2,
batch_size=32,
in_features=28 * 28,
learning_rate=0.001 * 8,
optimizer_name="adam",
data_root=_PATH_DATASETS,
out_features=10,
hidden_dim=1000,
b1=0.5,
b2=0.999,
)
if continue_training:
args.update(test_tube_do_checkpoint_load=True, hpc_exp_number=hpc_exp_number)
return args
T = TypeVar("T")
class GenericParentEvalModelTemplate(Generic[T], EvalModelTemplate):
def __init__(
self,
drop_prob: float,
batch_size: int,
in_features: int,
learning_rate: float,
optimizer_name: str,
data_root: str,
out_features: int,
hidden_dim: int,
b1: float,
b2: float,
):
super().__init__(
drop_prob=drop_prob,
batch_size=batch_size,
in_features=in_features,
learning_rate=learning_rate,
optimizer_name=optimizer_name,
data_root=data_root,
out_features=out_features,
hidden_dim=hidden_dim,
b1=b1,
b2=b2,
)
class GenericEvalModelTemplate(GenericParentEvalModelTemplate[int]):
pass
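The class above is assembled from the mixin files removed in this commit. For readers who want the behaviour in one place, a condensed, self-contained sketch drawn from those removed files (illustrative only; CondensedEvalModel is not a name from the repository):

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule

class CondensedEvalModel(LightningModule):
    """Condensed view of the pruned template: network, loss, training step, optimizer."""

    def __init__(self, in_features=28 * 28, hidden_dim=1000, out_features=10, drop_prob=0.2, learning_rate=0.001 * 8):
        super().__init__()
        self.save_hyperparameters()
        self.c_d1 = nn.Linear(in_features, hidden_dim)
        self.c_d1_bn = nn.BatchNorm1d(hidden_dim)
        self.c_d1_drop = nn.Dropout(drop_prob)
        self.c_d2 = nn.Linear(hidden_dim, out_features)

    def forward(self, x):
        x = self.c_d1_drop(self.c_d1_bn(torch.tanh(self.c_d1(x))))
        return F.softmax(self.c_d2(x), dim=1)

    def loss(self, labels, logits):
        return F.nll_loss(logits, labels)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x.view(x.size(0), -1))
        return {"loss": self.loss(y, y_hat)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)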

View File

@@ -1,36 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
class TestDataloaderVariations(ABC):
@abstractmethod
def dataloader(self, *args, **kwargs):
"""placeholder."""
def test_dataloader(self):
return self.dataloader(train=False)
def test_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=False))
def test_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))
def test_dataloader__multiple_mixed_length(self):
lengths = [50, 30, 40]
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
return dataloaders

View File

@@ -1,89 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
import torch
from pytorch_lightning.utilities import _StrategyType
class TestEpochEndVariations(ABC):
def test_epoch_end(self, outputs):
"""Called at the end of test epoch to aggregate outputs.
:param outputs: list of individual outputs of each validation step
:return:
"""
# if returned a scalar from test_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
# return torch.stack(outputs).mean()
test_loss_mean = 0
test_acc_mean = 0
for output in outputs:
test_loss = self.get_output_metric(output, "test_loss")
# reduce manually when using dp
if self.trainer._distrib_type == _StrategyType.DP:
test_loss = torch.mean(test_loss)
test_loss_mean += test_loss
# reduce manually when using dp
test_acc = self.get_output_metric(output, "test_acc")
if self.trainer._distrib_type == _StrategyType.DP:
test_acc = torch.mean(test_acc)
test_acc_mean += test_acc
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
metrics_dict = {"test_loss": test_loss_mean, "test_acc": test_acc_mean}
result = {"progress_bar": metrics_dict, "log": metrics_dict}
return result
def test_epoch_end__multiple_dataloaders(self, outputs):
"""Called at the end of test epoch to aggregate outputs.
:param outputs: list of individual outputs of each validation step
:return:
"""
# if returned a scalar from test_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
# return torch.stack(outputs).mean()
test_loss_mean = 0
test_acc_mean = 0
i = 0
for dl_output in outputs:
for output in dl_output:
test_loss = output["test_loss"]
# reduce manually when using dp
if self.trainer._distrib_type == _StrategyType.DP:
test_loss = torch.mean(test_loss)
test_loss_mean += test_loss
# reduce manually when using dp
test_acc = output["test_acc"]
if self.trainer._distrib_type == _StrategyType.DP:
test_acc = torch.mean(test_acc)
test_acc_mean += test_acc
i += 1
test_loss_mean /= i
test_acc_mean /= i
tqdm_dict = {"test_loss": test_loss_mean, "test_acc": test_acc_mean}
result = {"progress_bar": tqdm_dict}
return result

View File

@@ -1,90 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from collections import OrderedDict
import torch
class TestStepVariations(ABC):
"""Houses all variations of test steps."""
def test_step(self, batch, batch_idx, *args, **kwargs):
"""Default, baseline test_step.
:param batch:
:return:
"""
self.test_step_called = True
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_test = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
test_acc = torch.tensor(test_acc)
test_acc = test_acc.type_as(x)
# alternate possible outputs to test
if batch_idx % 1 == 0:
output = OrderedDict({"test_loss": loss_test, "test_acc": test_acc})
return output
if batch_idx % 2 == 0:
return test_acc
if batch_idx % 3 == 0:
output = OrderedDict(
{"test_loss": loss_test, "test_acc": test_acc, "test_dic": dict(test_loss_a=loss_test)}
)
return output
def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
"""Default, baseline test_step.
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_test = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
test_acc = torch.tensor(test_acc)
test_acc = test_acc.type_as(x)
# alternate possible outputs to test
if batch_idx % 1 == 0:
output = OrderedDict({"test_loss": loss_test, "test_acc": test_acc})
return output
if batch_idx % 2 == 0:
return test_acc
if batch_idx % 3 == 0:
output = OrderedDict(
{"test_loss": loss_test, "test_acc": test_acc, "test_dic": dict(test_loss_a=loss_test)}
)
return output
if batch_idx % 5 == 0:
output = OrderedDict({f"test_loss_{dataloader_idx}": loss_test, f"test_acc_{dataloader_idx}": test_acc})
return output

View File

@@ -1,51 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
class TrainDataloaderVariations(ABC):
@abstractmethod
def dataloader(self, train: bool, *args, **kwargs):
"""placeholder."""
def train_dataloader(self):
return self.dataloader(train=True)
def train_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=True))
def train_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=True))
def train_dataloader__zero_length(self):
dataloader = self.dataloader(train=True)
dataloader.dataset.data = dataloader.dataset.data[:0]
dataloader.dataset.targets = dataloader.dataset.targets[:0]
return dataloader
def train_dataloader__multiple_mapping(self):
"""Return a mapping loaders with different lengths."""
# List[DataLoader]
loaders_a_b = [self.dataloader(num_samples=100, train=True), self.dataloader(num_samples=50, train=True)]
loaders_c_d_e = [
self.dataloader(num_samples=50, train=True),
self.dataloader(num_samples=50, train=True),
self.dataloader(num_samples=50, train=True),
]
# Dict[str, List[DataLoader]]
loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e}
return loaders

View File

@@ -1,50 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
class TrainingStepVariations(ABC):
"""Houses all variations of training steps."""
def training_step(self, batch, batch_idx, optimizer_idx=None):
"""Lightning calls this inside the training loop."""
self.training_step_called = True
# forward pass
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
# calculate loss
loss_train = self.loss(y, y_hat)
return {"loss": loss_train}
def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None):
"""Training step for multiple train loaders."""
assert isinstance(batch, dict)
assert len(batch) == 2
assert "a_b" in batch and "c_d_e" in batch, batch.keys()
assert isinstance(batch["a_b"], list) and len(batch["a_b"]) == 2
assert isinstance(batch["c_d_e"], list) and len(batch["c_d_e"]) == 3
# forward pass
x, y = batch["a_b"][0]
x = x.view(x.size(0), -1)
y_hat = self(x)
# calculate loss
loss_val = self.loss(y, y_hat)
return {"loss": loss_val}

View File

@@ -1,33 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.utils.data import DataLoader
from tests.helpers.datasets import TrialMNIST
class ModelTemplateData:
def dataloader(self, train: bool, num_samples: int = 100):
dataset = TrialMNIST(root=self.data_root, train=train, num_samples=num_samples, download=True)
loader = DataLoader(dataset=dataset, batch_size=self.batch_size, num_workers=0, shuffle=train)
return loader
class ModelTemplateUtils:
def get_output_metric(self, output, name):
if isinstance(output, dict):
val = output[name]
else: # if it is 2level deep -> per dataloader and per batch
val = sum(out[name] for out in output) / len(output)
return val

View File

@@ -1,39 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
class ValDataloaderVariations(ABC):
@abstractmethod
def dataloader(self, *args, **kwargs):
"""placeholder."""
def val_dataloader(self):
return self.dataloader(train=False)
def val_dataloader__multiple_mixed_length(self):
lengths = [100, 30]
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
return dataloaders
def val_dataloader__multiple(self):
return [self.dataloader(train=False), self.dataloader(train=False)]
def val_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=False))
def val_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))

View File

@@ -1,73 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
import torch
class ValidationEpochEndVariations(ABC):
"""Houses all variations of validation_epoch_end steps."""
def validation_epoch_end(self, outputs):
"""Called at the end of validation to aggregate outputs.
Args:
outputs: list of individual outputs of each validation step
"""
# if returned a scalar from validation_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
def _mean(res, key):
# recursive mean for multilevel dicts
return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean()
val_loss_mean = _mean(outputs, "val_loss")
val_acc_mean = _mean(outputs, "val_acc")
# alternate between tensor and scalar
if self.current_epoch % 2 == 0:
val_loss_mean = val_loss_mean.item()
val_acc_mean = val_acc_mean.item()
self.log("early_stop_on", val_loss_mean, prog_bar=True)
self.log("val_acc", val_acc_mean, prog_bar=True)
def validation_epoch_end__multiple_dataloaders(self, outputs):
"""Called at the end of validation to aggregate outputs.
Args:
outputs: list of individual outputs of each validation step
"""
# if returned a scalar from validation_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
def _mean(res, key):
return torch.stack([x[key] for x in res]).mean()
pbar = {}
logs = {}
for dl_output_list in outputs:
output_keys = dl_output_list[0].keys()
output_keys = [x for x in output_keys if "val_" in x]
for key in output_keys:
metric_out = _mean(dl_output_list, key)
pbar[key] = metric_out
logs[key] = metric_out
results = {
"val_loss": torch.stack([v for k, v in pbar.items() if k.startswith("val_loss")]).mean(),
"progress_bar": pbar,
"log": logs,
}
return results

View File

@@ -1,80 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from collections import OrderedDict
import torch
class ValidationStepVariations(ABC):
"""Houses all variations of validation steps."""
def validation_step(self, batch, batch_idx, *args, **kwargs):
"""Lightning calls this inside the validation loop.
:param batch:
:return:
"""
self.validation_step_called = True
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_val = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc).type_as(x)
output = OrderedDict({"val_loss": loss_val, "val_acc": val_acc, "test_dic": dict(val_loss_a=loss_val)})
return output
def validation_step__dp(self, batch, batch_idx, *args, **kwargs):
self.validation_step_called = True
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x.to(self.device))
y = y.to(y_hat.device)
loss_val = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc).type_as(x)
self.log("val_loss", loss_val)
self.log("val_acc", val_acc)
return loss_val
def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
"""Lightning calls this inside the validation loop.
:param batch:
:return:
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
loss_val = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc).type_as(x)
output = OrderedDict({f"val_loss_{dataloader_idx}": loss_val, f"val_acc_{dataloader_idx}": val_acc})
return output

View File

@@ -41,7 +41,7 @@ def run_model_test_without_loggers(
if not isinstance(model2, BoringModel):
for dataloader in test_loaders:
run_prediction_eval_model_template(model2, dataloader, min_acc=min_acc)
run_model_prediction(model2, dataloader, min_acc=min_acc)
def run_model_test(
@@ -79,7 +79,7 @@ def run_model_test(
if not isinstance(model, BoringModel):
for dataloader in test_loaders:
run_prediction_eval_model_template(model, dataloader, min_acc=min_acc)
run_model_prediction(model, dataloader, min_acc=min_acc)
if with_hpc:
if trainer._distrib_type in (_StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2):
@@ -96,7 +96,7 @@ def run_model_test(
@torch.no_grad()
def run_prediction_eval_model_template(trained_model, dataloader, min_acc=0.50):
def run_model_prediction(trained_model, dataloader, min_acc=0.50):
orig_device = trained_model.device
# run prediction on 1 batch
trained_model.cpu()

View File

@@ -21,9 +21,9 @@ import pytest
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
from pytorch_lightning.loggers import TensorBoardLogger
from tests import _TEMP_PATH, RANDOM_PORTS
from tests.base.model_template import EvalModelTemplate
from tests.helpers.boring_model import BoringModel
def get_default_logger(save_dir, version=None):
@@ -38,11 +38,6 @@ def get_data_path(expt_logger, path_dir=None):
# each logger has to have these attributes
name, version = expt_logger.name, expt_logger.version
# only the test-tube experiment has such attribute
if isinstance(expt_logger, TestTubeLogger):
expt = expt_logger.experiment if hasattr(expt_logger, "experiment") else expt_logger
return expt.get_data_path(name, version)
# the other experiments...
if not path_dir:
if hasattr(expt_logger, "save_dir") and expt_logger.save_dir:
@@ -57,7 +52,7 @@ def get_data_path(expt_logger, path_dir=None):
return path_expt
def load_model_from_checkpoint(logger, root_weights_dir, module_class=EvalModelTemplate):
def load_model_from_checkpoint(logger, root_weights_dir, module_class=BoringModel):
trained_model = module_class.load_from_checkpoint(root_weights_dir)
assert trained_model is not None, "loading model failed"
return trained_model

View File

@@ -346,7 +346,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
dataloaders = [dataloaders]
for dataloader in dataloaders:
tpipes.run_prediction_eval_model_template(pretrained_model, dataloader)
tpipes.run_model_prediction(pretrained_model, dataloader)
@RunIf(min_gpus=2)
@@ -394,7 +394,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
dataloaders = [dataloaders]
for dataloader in dataloaders:
tpipes.run_prediction_eval_model_template(pretrained_model, dataloader, min_acc=0.1)
tpipes.run_model_prediction(pretrained_model, dataloader, min_acc=0.1)
def test_running_test_pretrained_model_cpu(tmpdir):
@@ -540,7 +540,7 @@ def test_dp_resume(tmpdir):
new_trainer.state.stage = RunningStage.VALIDATING
dataloader = dm.train_dataloader()
tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader)
tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
self.on_pretrain_routine_end_called = True
# new model