lightning/tests/parity_pytorch/models.py

63 lines
2.3 KiB
Python

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.pytorch.utilities.model_helpers import get_torchvision_model
from tests_pytorch import _PATH_DATASETS
if _TORCHVISION_AVAILABLE:
from torchvision import transforms
from torchvision.datasets import CIFAR10
class ParityModuleCIFAR(LightningModule):
    """CIFAR-10 image classifier used to compare Lightning training with plain PyTorch.

    The network is a torchvision backbone (built via ``get_torchvision_model``) followed by a
    two-layer linear head. The loss of every training step is appended to ``self._loss`` so the
    recorded sequence can be checked against the losses of a vanilla-torch training loop.

    Args:
        backbone: Model name forwarded to ``get_torchvision_model``.
        hidden_dim: Width of the hidden layer in the linear head.
        learning_rate: Learning rate given to ``torch.optim.Adam``.
        weights: Weights specifier forwarded to ``get_torchvision_model``.

    """

    def __init__(self, backbone="resnet101", hidden_dim=1024, learning_rate=1e-3, weights="DEFAULT"):
        super().__init__()
        # Called directly from ``__init__`` because it collects the constructor arguments
        # from this frame into ``self.hparams``.
        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.num_classes = 10  # number of CIFAR-10 labels

        # Creation order (backbone, then the two head layers) fixes both the parameter order
        # and the sequence of random initialisations, so it is kept unchanged.
        self.backbone = get_torchvision_model(backbone, weights=weights)
        # NOTE(review): the 1000 input features assume the backbone emits a 1000-dim output
        # (ImageNet-style classifier); confirm when using a backbone with a different head.
        head_layers = [
            torch.nn.Linear(1000, hidden_dim),
            torch.nn.Linear(hidden_dim, self.num_classes),
        ]
        self.classifier = torch.nn.Sequential(*head_layers)

        # ToTensor yields values in [0, 1]; normalising with mean 0.5 / std 0.5 per channel
        # maps each of the three channels to [-1, 1].
        channel_mean = (0.5, 0.5, 0.5)
        channel_std = (0.5, 0.5, 0.5)
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
        )

        # Per-step loss history (plain floats) used to check parity with vanilla torch.
        self._loss: list[float] = []

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch.

        ``batch`` is unpacked as an ``(inputs, targets)`` pair; ``batch_idx`` is not used.
        The scalar loss value is recorded in ``self._loss`` before being returned.

        """
        inputs, targets = batch
        logits = self.classifier(self.backbone(inputs))
        loss = F.cross_entropy(logits, targets)
        # ``.item()`` stores a detached Python float rather than the graph-attached tensor.
        self._loss.append(loss.item())
        return {"loss": loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using ``self.learning_rate``."""
        params = self.parameters()
        return torch.optim.Adam(params, lr=self.learning_rate)

    def train_dataloader(self):
        """Build the CIFAR-10 training loader (batch size 32, one worker process).

        The training split is downloaded into ``_PATH_DATASETS`` when missing and is
        transformed with ``self.transform``. ``shuffle`` is left at the ``DataLoader``
        default (``False``), so batches are drawn in dataset order.

        """
        train_set = CIFAR10(root=_PATH_DATASETS, train=True, download=True, transform=self.transform)
        return DataLoader(train_set, batch_size=32, num_workers=1)