Prune EvalModelTemplate (#11153)
This commit is contained in:
parent
f95976d602
commit
61eb6230c2
|
@ -1,3 +0,0 @@
|
|||
"""Models for testing."""
|
||||
|
||||
from tests.base.model_template import EvalModelTemplate, GenericEvalModelTemplate # noqa: F401
|
|
@ -1,57 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC
|
||||
|
||||
from torch import optim
|
||||
|
||||
|
||||
class ConfigureOptimizersPool(ABC):
|
||||
def configure_optimizers(self):
|
||||
"""return whatever optimizers we want here.
|
||||
|
||||
:return: list of optimizers
|
||||
"""
|
||||
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
def configure_optimizers__adagrad(self):
|
||||
optimizer = optim.Adagrad(self.parameters(), lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
def configure_optimizers__single_scheduler(self):
|
||||
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
def configure_optimizers__multiple_schedulers(self):
|
||||
optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||
optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
|
||||
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
|
||||
|
||||
return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]
|
||||
|
||||
def configure_optimizers__param_groups(self):
|
||||
param_groups = [
|
||||
{"params": list(self.parameters())[:2], "lr": self.learning_rate * 0.1},
|
||||
{"params": list(self.parameters())[2:], "lr": self.learning_rate},
|
||||
]
|
||||
|
||||
optimizer = optim.Adam(param_groups)
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
def configure_optimizers__lr_from_hparams(self):
|
||||
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
||||
return optimizer
|
|
@ -1,178 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from tests import _PATH_DATASETS
|
||||
from tests.base.model_optimizers import ConfigureOptimizersPool
|
||||
from tests.base.model_test_dataloaders import TestDataloaderVariations
|
||||
from tests.base.model_test_epoch_ends import TestEpochEndVariations
|
||||
from tests.base.model_test_steps import TestStepVariations
|
||||
from tests.base.model_train_dataloaders import TrainDataloaderVariations
|
||||
from tests.base.model_train_steps import TrainingStepVariations
|
||||
from tests.base.model_utilities import ModelTemplateData, ModelTemplateUtils
|
||||
from tests.base.model_valid_dataloaders import ValDataloaderVariations
|
||||
from tests.base.model_valid_epoch_ends import ValidationEpochEndVariations
|
||||
from tests.base.model_valid_steps import ValidationStepVariations
|
||||
from tests.helpers.datasets import TrialMNIST
|
||||
|
||||
|
||||
class EvalModelTemplate(
|
||||
ModelTemplateData,
|
||||
ModelTemplateUtils,
|
||||
TrainingStepVariations,
|
||||
ValidationStepVariations,
|
||||
ValidationEpochEndVariations,
|
||||
TestStepVariations,
|
||||
TestEpochEndVariations,
|
||||
TrainDataloaderVariations,
|
||||
ValDataloaderVariations,
|
||||
TestDataloaderVariations,
|
||||
ConfigureOptimizersPool,
|
||||
LightningModule,
|
||||
):
|
||||
"""This template houses all combinations of model configurations we want to test.
|
||||
|
||||
>>> model = EvalModelTemplate()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
drop_prob: float = 0.2,
|
||||
batch_size: int = 32,
|
||||
in_features: int = 28 * 28,
|
||||
learning_rate: float = 0.001 * 8,
|
||||
optimizer_name: str = "adam",
|
||||
data_root: str = _PATH_DATASETS,
|
||||
out_features: int = 10,
|
||||
hidden_dim: int = 1000,
|
||||
b1: float = 0.5,
|
||||
b2: float = 0.999,
|
||||
):
|
||||
# init superclass
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
self.drop_prob = drop_prob
|
||||
self.batch_size = batch_size
|
||||
self.in_features = in_features
|
||||
self.learning_rate = learning_rate
|
||||
self.optimizer_name = optimizer_name
|
||||
self.data_root = data_root
|
||||
self.out_features = out_features
|
||||
self.hidden_dim = hidden_dim
|
||||
self.b1 = b1
|
||||
self.b2 = b2
|
||||
self.training_step_called = False
|
||||
self.training_step_end_called = False
|
||||
self.training_epoch_end_called = False
|
||||
self.validation_step_called = False
|
||||
self.validation_step_end_called = False
|
||||
self.validation_epoch_end_called = False
|
||||
self.test_step_called = False
|
||||
self.test_step_end_called = False
|
||||
self.test_epoch_end_called = False
|
||||
|
||||
self.example_input_array = torch.rand(5, 28 * 28)
|
||||
|
||||
# build model
|
||||
self.__build_model()
|
||||
|
||||
def __build_model(self):
|
||||
"""
|
||||
Simple model for testing
|
||||
:return:
|
||||
"""
|
||||
self.c_d1 = nn.Linear(in_features=self.in_features, out_features=self.hidden_dim)
|
||||
self.c_d1_bn = nn.BatchNorm1d(self.hidden_dim)
|
||||
self.c_d1_drop = nn.Dropout(self.drop_prob)
|
||||
|
||||
self.c_d2 = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_d1(x)
|
||||
x = torch.tanh(x)
|
||||
x = self.c_d1_bn(x)
|
||||
x = self.c_d1_drop(x)
|
||||
|
||||
x = self.c_d2(x)
|
||||
logits = F.softmax(x, dim=1)
|
||||
|
||||
return logits
|
||||
|
||||
def loss(self, labels, logits):
|
||||
nll = F.nll_loss(logits, labels)
|
||||
return nll
|
||||
|
||||
def prepare_data(self):
|
||||
TrialMNIST(root=self.data_root, train=True, download=True)
|
||||
|
||||
@staticmethod
|
||||
def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0) -> dict:
|
||||
args = dict(
|
||||
drop_prob=0.2,
|
||||
batch_size=32,
|
||||
in_features=28 * 28,
|
||||
learning_rate=0.001 * 8,
|
||||
optimizer_name="adam",
|
||||
data_root=_PATH_DATASETS,
|
||||
out_features=10,
|
||||
hidden_dim=1000,
|
||||
b1=0.5,
|
||||
b2=0.999,
|
||||
)
|
||||
|
||||
if continue_training:
|
||||
args.update(test_tube_do_checkpoint_load=True, hpc_exp_number=hpc_exp_number)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class GenericParentEvalModelTemplate(Generic[T], EvalModelTemplate):
|
||||
def __init__(
|
||||
self,
|
||||
drop_prob: float,
|
||||
batch_size: int,
|
||||
in_features: int,
|
||||
learning_rate: float,
|
||||
optimizer_name: str,
|
||||
data_root: str,
|
||||
out_features: int,
|
||||
hidden_dim: int,
|
||||
b1: float,
|
||||
b2: float,
|
||||
):
|
||||
super().__init__(
|
||||
drop_prob=drop_prob,
|
||||
batch_size=batch_size,
|
||||
in_features=in_features,
|
||||
learning_rate=learning_rate,
|
||||
optimizer_name=optimizer_name,
|
||||
data_root=data_root,
|
||||
out_features=out_features,
|
||||
hidden_dim=hidden_dim,
|
||||
b1=b1,
|
||||
b2=b2,
|
||||
)
|
||||
|
||||
|
||||
class GenericEvalModelTemplate(GenericParentEvalModelTemplate[int]):
|
||||
pass
|
|
@ -1,36 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
|
||||
|
||||
|
||||
class TestDataloaderVariations(ABC):
|
||||
@abstractmethod
|
||||
def dataloader(self, *args, **kwargs):
|
||||
"""placeholder."""
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.dataloader(train=False)
|
||||
|
||||
def test_dataloader__infinite(self):
|
||||
return CustomInfDataloader(self.dataloader(train=False))
|
||||
|
||||
def test_dataloader__not_implemented_error(self):
|
||||
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))
|
||||
|
||||
def test_dataloader__multiple_mixed_length(self):
|
||||
lengths = [50, 30, 40]
|
||||
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
|
||||
return dataloaders
|
|
@ -1,89 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.utilities import _StrategyType
|
||||
|
||||
|
||||
class TestEpochEndVariations(ABC):
|
||||
def test_epoch_end(self, outputs):
|
||||
"""Called at the end of test epoch to aggregate outputs.
|
||||
|
||||
:param outputs: list of individual outputs of each validation step
|
||||
:return:
|
||||
"""
|
||||
# if returned a scalar from test_step, outputs is a list of tensor scalars
|
||||
# we return just the average in this case (if we want)
|
||||
# return torch.stack(outputs).mean()
|
||||
test_loss_mean = 0
|
||||
test_acc_mean = 0
|
||||
for output in outputs:
|
||||
test_loss = self.get_output_metric(output, "test_loss")
|
||||
|
||||
# reduce manually when using dp
|
||||
if self.trainer._distrib_type == _StrategyType.DP:
|
||||
test_loss = torch.mean(test_loss)
|
||||
test_loss_mean += test_loss
|
||||
|
||||
# reduce manually when using dp
|
||||
test_acc = self.get_output_metric(output, "test_acc")
|
||||
if self.trainer._distrib_type == _StrategyType.DP:
|
||||
test_acc = torch.mean(test_acc)
|
||||
|
||||
test_acc_mean += test_acc
|
||||
|
||||
test_loss_mean /= len(outputs)
|
||||
test_acc_mean /= len(outputs)
|
||||
|
||||
metrics_dict = {"test_loss": test_loss_mean, "test_acc": test_acc_mean}
|
||||
result = {"progress_bar": metrics_dict, "log": metrics_dict}
|
||||
return result
|
||||
|
||||
def test_epoch_end__multiple_dataloaders(self, outputs):
|
||||
"""Called at the end of test epoch to aggregate outputs.
|
||||
|
||||
:param outputs: list of individual outputs of each validation step
|
||||
:return:
|
||||
"""
|
||||
# if returned a scalar from test_step, outputs is a list of tensor scalars
|
||||
# we return just the average in this case (if we want)
|
||||
# return torch.stack(outputs).mean()
|
||||
test_loss_mean = 0
|
||||
test_acc_mean = 0
|
||||
i = 0
|
||||
for dl_output in outputs:
|
||||
for output in dl_output:
|
||||
test_loss = output["test_loss"]
|
||||
|
||||
# reduce manually when using dp
|
||||
if self.trainer._distrib_type == _StrategyType.DP:
|
||||
test_loss = torch.mean(test_loss)
|
||||
test_loss_mean += test_loss
|
||||
|
||||
# reduce manually when using dp
|
||||
test_acc = output["test_acc"]
|
||||
if self.trainer._distrib_type == _StrategyType.DP:
|
||||
test_acc = torch.mean(test_acc)
|
||||
|
||||
test_acc_mean += test_acc
|
||||
i += 1
|
||||
|
||||
test_loss_mean /= i
|
||||
test_acc_mean /= i
|
||||
|
||||
tqdm_dict = {"test_loss": test_loss_mean, "test_acc": test_acc_mean}
|
||||
result = {"progress_bar": tqdm_dict}
|
||||
return result
|
|
@ -1,90 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TestStepVariations(ABC):
|
||||
"""Houses all variations of test steps."""
|
||||
|
||||
def test_step(self, batch, batch_idx, *args, **kwargs):
|
||||
"""Default, baseline test_step.
|
||||
|
||||
:param batch:
|
||||
:return:
|
||||
"""
|
||||
self.test_step_called = True
|
||||
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self(x)
|
||||
|
||||
loss_test = self.loss(y, y_hat)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
test_acc = torch.tensor(test_acc)
|
||||
|
||||
test_acc = test_acc.type_as(x)
|
||||
|
||||
# alternate possible outputs to test
|
||||
if batch_idx % 1 == 0:
|
||||
output = OrderedDict({"test_loss": loss_test, "test_acc": test_acc})
|
||||
return output
|
||||
if batch_idx % 2 == 0:
|
||||
return test_acc
|
||||
|
||||
if batch_idx % 3 == 0:
|
||||
output = OrderedDict(
|
||||
{"test_loss": loss_test, "test_acc": test_acc, "test_dic": dict(test_loss_a=loss_test)}
|
||||
)
|
||||
return output
|
||||
|
||||
def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
|
||||
"""Default, baseline test_step.
|
||||
|
||||
:param batch:
|
||||
:return:
|
||||
"""
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self(x)
|
||||
|
||||
loss_test = self.loss(y, y_hat)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
test_acc = torch.tensor(test_acc)
|
||||
|
||||
test_acc = test_acc.type_as(x)
|
||||
|
||||
# alternate possible outputs to test
|
||||
if batch_idx % 1 == 0:
|
||||
output = OrderedDict({"test_loss": loss_test, "test_acc": test_acc})
|
||||
return output
|
||||
if batch_idx % 2 == 0:
|
||||
return test_acc
|
||||
|
||||
if batch_idx % 3 == 0:
|
||||
output = OrderedDict(
|
||||
{"test_loss": loss_test, "test_acc": test_acc, "test_dic": dict(test_loss_a=loss_test)}
|
||||
)
|
||||
return output
|
||||
if batch_idx % 5 == 0:
|
||||
output = OrderedDict({f"test_loss_{dataloader_idx}": loss_test, f"test_acc_{dataloader_idx}": test_acc})
|
||||
return output
|
|
@ -1,51 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
|
||||
|
||||
|
||||
class TrainDataloaderVariations(ABC):
|
||||
@abstractmethod
|
||||
def dataloader(self, train: bool, *args, **kwargs):
|
||||
"""placeholder."""
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.dataloader(train=True)
|
||||
|
||||
def train_dataloader__infinite(self):
|
||||
return CustomInfDataloader(self.dataloader(train=True))
|
||||
|
||||
def train_dataloader__not_implemented_error(self):
|
||||
return CustomNotImplementedErrorDataloader(self.dataloader(train=True))
|
||||
|
||||
def train_dataloader__zero_length(self):
|
||||
dataloader = self.dataloader(train=True)
|
||||
dataloader.dataset.data = dataloader.dataset.data[:0]
|
||||
dataloader.dataset.targets = dataloader.dataset.targets[:0]
|
||||
return dataloader
|
||||
|
||||
def train_dataloader__multiple_mapping(self):
|
||||
"""Return a mapping loaders with different lengths."""
|
||||
|
||||
# List[DataLoader]
|
||||
loaders_a_b = [self.dataloader(num_samples=100, train=True), self.dataloader(num_samples=50, train=True)]
|
||||
loaders_c_d_e = [
|
||||
self.dataloader(num_samples=50, train=True),
|
||||
self.dataloader(num_samples=50, train=True),
|
||||
self.dataloader(num_samples=50, train=True),
|
||||
]
|
||||
# Dict[str, List[DataLoader]]
|
||||
loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e}
|
||||
return loaders
|
|
@ -1,50 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class TrainingStepVariations(ABC):
|
||||
"""Houses all variations of training steps."""
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
"""Lightning calls this inside the training loop."""
|
||||
self.training_step_called = True
|
||||
|
||||
# forward pass
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self(x)
|
||||
|
||||
# calculate loss
|
||||
loss_train = self.loss(y, y_hat)
|
||||
return {"loss": loss_train}
|
||||
|
||||
def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None):
|
||||
"""Training step for multiple train loaders."""
|
||||
|
||||
assert isinstance(batch, dict)
|
||||
assert len(batch) == 2
|
||||
|
||||
assert "a_b" in batch and "c_d_e" in batch, batch.keys()
|
||||
assert isinstance(batch["a_b"], list) and len(batch["a_b"]) == 2
|
||||
assert isinstance(batch["c_d_e"], list) and len(batch["c_d_e"]) == 3
|
||||
|
||||
# forward pass
|
||||
x, y = batch["a_b"][0]
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self(x)
|
||||
|
||||
# calculate loss
|
||||
loss_val = self.loss(y, y_hat)
|
||||
return {"loss": loss_val}
|
|
@ -1,33 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from tests.helpers.datasets import TrialMNIST
|
||||
|
||||
|
||||
class ModelTemplateData:
|
||||
def dataloader(self, train: bool, num_samples: int = 100):
|
||||
dataset = TrialMNIST(root=self.data_root, train=train, num_samples=num_samples, download=True)
|
||||
|
||||
loader = DataLoader(dataset=dataset, batch_size=self.batch_size, num_workers=0, shuffle=train)
|
||||
return loader
|
||||
|
||||
|
||||
class ModelTemplateUtils:
|
||||
def get_output_metric(self, output, name):
|
||||
if isinstance(output, dict):
|
||||
val = output[name]
|
||||
else: # if it is 2level deep -> per dataloader and per batch
|
||||
val = sum(out[name] for out in output) / len(output)
|
||||
return val
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
|
||||
|
||||
|
||||
class ValDataloaderVariations(ABC):
|
||||
@abstractmethod
|
||||
def dataloader(self, *args, **kwargs):
|
||||
"""placeholder."""
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.dataloader(train=False)
|
||||
|
||||
def val_dataloader__multiple_mixed_length(self):
|
||||
lengths = [100, 30]
|
||||
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
|
||||
return dataloaders
|
||||
|
||||
def val_dataloader__multiple(self):
|
||||
return [self.dataloader(train=False), self.dataloader(train=False)]
|
||||
|
||||
def val_dataloader__infinite(self):
|
||||
return CustomInfDataloader(self.dataloader(train=False))
|
||||
|
||||
def val_dataloader__not_implemented_error(self):
|
||||
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))
|
|
@ -1,73 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ValidationEpochEndVariations(ABC):
|
||||
"""Houses all variations of validation_epoch_end steps."""
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
"""Called at the end of validation to aggregate outputs.
|
||||
|
||||
Args:
|
||||
outputs: list of individual outputs of each validation step
|
||||
"""
|
||||
|
||||
# if returned a scalar from validation_step, outputs is a list of tensor scalars
|
||||
# we return just the average in this case (if we want)
|
||||
def _mean(res, key):
|
||||
# recursive mean for multilevel dicts
|
||||
return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean()
|
||||
|
||||
val_loss_mean = _mean(outputs, "val_loss")
|
||||
val_acc_mean = _mean(outputs, "val_acc")
|
||||
|
||||
# alternate between tensor and scalar
|
||||
if self.current_epoch % 2 == 0:
|
||||
val_loss_mean = val_loss_mean.item()
|
||||
val_acc_mean = val_acc_mean.item()
|
||||
|
||||
self.log("early_stop_on", val_loss_mean, prog_bar=True)
|
||||
self.log("val_acc", val_acc_mean, prog_bar=True)
|
||||
|
||||
def validation_epoch_end__multiple_dataloaders(self, outputs):
|
||||
"""Called at the end of validation to aggregate outputs.
|
||||
|
||||
Args:
|
||||
outputs: list of individual outputs of each validation step
|
||||
"""
|
||||
|
||||
# if returned a scalar from validation_step, outputs is a list of tensor scalars
|
||||
# we return just the average in this case (if we want)
|
||||
def _mean(res, key):
|
||||
return torch.stack([x[key] for x in res]).mean()
|
||||
|
||||
pbar = {}
|
||||
logs = {}
|
||||
for dl_output_list in outputs:
|
||||
output_keys = dl_output_list[0].keys()
|
||||
output_keys = [x for x in output_keys if "val_" in x]
|
||||
for key in output_keys:
|
||||
metric_out = _mean(dl_output_list, key)
|
||||
pbar[key] = metric_out
|
||||
logs[key] = metric_out
|
||||
|
||||
results = {
|
||||
"val_loss": torch.stack([v for k, v in pbar.items() if k.startswith("val_loss")]).mean(),
|
||||
"progress_bar": pbar,
|
||||
"log": logs,
|
||||
}
|
||||
return results
|
|
@ -1,80 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ValidationStepVariations(ABC):
|
||||
"""Houses all variations of validation steps."""
|
||||
|
||||
def validation_step(self, batch, batch_idx, *args, **kwargs):
|
||||
"""Lightning calls this inside the validation loop.
|
||||
|
||||
:param batch:
|
||||
:return:
|
||||
"""
|
||||
self.validation_step_called = True
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self(x)
|
||||
|
||||
loss_val = self.loss(y, y_hat)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc).type_as(x)
|
||||
|
||||
output = OrderedDict({"val_loss": loss_val, "val_acc": val_acc, "test_dic": dict(val_loss_a=loss_val)})
|
||||
return output
|
||||
|
||||
def validation_step__dp(self, batch, batch_idx, *args, **kwargs):
|
||||
self.validation_step_called = True
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self(x.to(self.device))
|
||||
|
||||
y = y.to(y_hat.device)
|
||||
loss_val = self.loss(y, y_hat)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc).type_as(x)
|
||||
|
||||
self.log("val_loss", loss_val)
|
||||
self.log("val_acc", val_acc)
|
||||
return loss_val
|
||||
|
||||
def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
|
||||
"""Lightning calls this inside the validation loop.
|
||||
|
||||
:param batch:
|
||||
:return:
|
||||
"""
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self(x)
|
||||
|
||||
loss_val = self.loss(y, y_hat)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc).type_as(x)
|
||||
|
||||
output = OrderedDict({f"val_loss_{dataloader_idx}": loss_val, f"val_acc_{dataloader_idx}": val_acc})
|
||||
return output
|
|
@ -41,7 +41,7 @@ def run_model_test_without_loggers(
|
|||
|
||||
if not isinstance(model2, BoringModel):
|
||||
for dataloader in test_loaders:
|
||||
run_prediction_eval_model_template(model2, dataloader, min_acc=min_acc)
|
||||
run_model_prediction(model2, dataloader, min_acc=min_acc)
|
||||
|
||||
|
||||
def run_model_test(
|
||||
|
@ -79,7 +79,7 @@ def run_model_test(
|
|||
|
||||
if not isinstance(model, BoringModel):
|
||||
for dataloader in test_loaders:
|
||||
run_prediction_eval_model_template(model, dataloader, min_acc=min_acc)
|
||||
run_model_prediction(model, dataloader, min_acc=min_acc)
|
||||
|
||||
if with_hpc:
|
||||
if trainer._distrib_type in (_StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2):
|
||||
|
@ -96,7 +96,7 @@ def run_model_test(
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def run_prediction_eval_model_template(trained_model, dataloader, min_acc=0.50):
|
||||
def run_model_prediction(trained_model, dataloader, min_acc=0.50):
|
||||
orig_device = trained_model.device
|
||||
# run prediction on 1 batch
|
||||
trained_model.cpu()
|
||||
|
|
|
@ -21,9 +21,9 @@ import pytest
|
|||
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from tests import _TEMP_PATH, RANDOM_PORTS
|
||||
from tests.base.model_template import EvalModelTemplate
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
|
||||
|
||||
def get_default_logger(save_dir, version=None):
|
||||
|
@ -38,11 +38,6 @@ def get_data_path(expt_logger, path_dir=None):
|
|||
# each logger has to have these attributes
|
||||
name, version = expt_logger.name, expt_logger.version
|
||||
|
||||
# only the test-tube experiment has such attribute
|
||||
if isinstance(expt_logger, TestTubeLogger):
|
||||
expt = expt_logger.experiment if hasattr(expt_logger, "experiment") else expt_logger
|
||||
return expt.get_data_path(name, version)
|
||||
|
||||
# the other experiments...
|
||||
if not path_dir:
|
||||
if hasattr(expt_logger, "save_dir") and expt_logger.save_dir:
|
||||
|
@ -57,7 +52,7 @@ def get_data_path(expt_logger, path_dir=None):
|
|||
return path_expt
|
||||
|
||||
|
||||
def load_model_from_checkpoint(logger, root_weights_dir, module_class=EvalModelTemplate):
|
||||
def load_model_from_checkpoint(logger, root_weights_dir, module_class=BoringModel):
|
||||
trained_model = module_class.load_from_checkpoint(root_weights_dir)
|
||||
assert trained_model is not None, "loading model failed"
|
||||
return trained_model
|
||||
|
|
|
@ -346,7 +346,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
|
|||
dataloaders = [dataloaders]
|
||||
|
||||
for dataloader in dataloaders:
|
||||
tpipes.run_prediction_eval_model_template(pretrained_model, dataloader)
|
||||
tpipes.run_model_prediction(pretrained_model, dataloader)
|
||||
|
||||
|
||||
@RunIf(min_gpus=2)
|
||||
|
@ -394,7 +394,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
|
|||
dataloaders = [dataloaders]
|
||||
|
||||
for dataloader in dataloaders:
|
||||
tpipes.run_prediction_eval_model_template(pretrained_model, dataloader, min_acc=0.1)
|
||||
tpipes.run_model_prediction(pretrained_model, dataloader, min_acc=0.1)
|
||||
|
||||
|
||||
def test_running_test_pretrained_model_cpu(tmpdir):
|
||||
|
@ -540,7 +540,7 @@ def test_dp_resume(tmpdir):
|
|||
new_trainer.state.stage = RunningStage.VALIDATING
|
||||
|
||||
dataloader = dm.train_dataloader()
|
||||
tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader)
|
||||
tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
|
||||
self.on_pretrain_routine_end_called = True
|
||||
|
||||
# new model
|
||||
|
|
Loading…
Reference in New Issue