import os
|
|
import traceback
|
|
from argparse import ArgumentParser
|
|
from typing import Callable, Literal, Optional
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
import torch.optim.lr_scheduler as lr_scheduler
|
|
from lightning import LightningModule, Trainer
|
|
from lightning.data import LightningDataset
|
|
from lightning.pytorch.utilities.model_helpers import get_torchvision_model
|
|
from torch.utils.data import Dataset
|
|
|
|
# Command-line options for the training script.
parser = ArgumentParser()
parser.add_argument("--workers", type=int, default=4)
parser.add_argument("--batchsize", type=int, default=56)
parser.add_argument(
    "-e",
    "--evaluate",
    dest="evaluate",
    action="store_true",
    help="evaluate model on validation set",
)
args = parser.parse_args()
|
|
|
|
# --------------------------------
|
|
# Step 1: Define a LightningModule
|
|
# --------------------------------
|
|
|
|
|
|
class ImageNetLightningModel(LightningModule):
    """Train/evaluate a torchvision classifier on ImageNet data streamed from S3.

    >>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    ImageNetLightningModel(
      (model): ResNet(...)
    )
    """

    # Class-scope import so the default value of ``weights`` in __init__ resolves.
    from torchvision.models.resnet import ResNet18_Weights

    def __init__(
        self,
        data_path: str,
        index_file_path: Optional[str] = None,
        arch: str = "resnet18",
        weights=ResNet18_Weights.IMAGENET1K_V1,
        lr: float = 1e-4,
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        batch_size: int = 256,
        workers: int = 4,
    ):
        """
        Args:
            data_path: dataset location passed to ``S3LightningImagenetDataset``
                (an S3 URI in this example).
            index_file_path: optional path to a pre-built index file.
            arch: torchvision architecture name resolved by ``get_torchvision_model``.
            weights: pretrained-weights enum value, or ``None`` to start from scratch.
            lr: initial SGD learning rate.
            momentum: SGD momentum.
            weight_decay: SGD weight decay.
            batch_size: batch size used by all dataloaders.
            workers: number of dataloader worker processes.
        """
        super().__init__()
        self.arch = arch
        self.weights = weights
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.workers = workers
        self.data_path = data_path
        self.index_file_path = index_file_path
        self.model = get_torchvision_model(self.arch, weights=self.weights)
        # NOTE(review): these are assigned but never read in this class — kept
        # for interface compatibility; confirm whether callers rely on them.
        self.train_dataset: Optional[Dataset] = None
        self.eval_dataset: Optional[Dataset] = None

    def forward(self, x):
        """Forward pass through the wrapped torchvision model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log cross-entropy loss for one training batch."""
        images, target = batch
        output = self.model(images)
        loss_train = F.cross_entropy(output, target)
        self.log("train_loss", loss_train)
        return loss_train

    def eval_step(self, batch, batch_idx, prefix: str):
        """Shared validation/test step; logs the loss as ``{prefix}_loss``."""
        images, target = batch
        output = self.model(images)
        loss_val = F.cross_entropy(output, target)
        self.log(f"{prefix}_loss", loss_val)
        return loss_val

    def validation_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        """SGD with a 10x learning-rate decay every 30 epochs."""
        optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30))
        return [optimizer], [scheduler]

    def _make_dataloader(self, split: str, shuffle: bool):
        """Build a DataLoader over the S3-backed dataset for ``split``.

        Extracted helper: train/val dataloader construction was duplicated.
        """
        import torchvision as tv

        # NOTE(review): RandomResizedCrop is a *training* augmentation but was
        # also applied to the validation split in the original code; this is
        # preserved — consider Resize + CenterCrop for deterministic evaluation.
        transforms = tv.transforms.Compose([tv.transforms.RandomResizedCrop(224), tv.transforms.ToTensor()])

        dataset = S3LightningImagenetDataset(
            data_source=self.data_path, split=split, transforms=transforms, path_to_index_file=self.index_file_path
        )

        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.workers
        )

    def train_dataloader(self):
        return self._make_dataloader("train", shuffle=True)

    def val_dataloader(self):
        # Fix: validation data must not be shuffled — shuffling gains nothing
        # for evaluation and makes per-epoch results non-reproducible.
        return self._make_dataloader("val", shuffle=False)

    def test_dataloader(self):
        return self.val_dataloader()
|
|
|
|
|
|
# -------------------
|
|
# Step 2: Define data
|
|
# -------------------
|
|
|
|
|
|
class S3LightningImagenetDataset(LightningDataset):
    """Streams ImageNet-style images from S3, yielding ``(image, class_index)``."""

    def __init__(
        self,
        data_source: str,
        split: Literal["train", "val"],
        transforms: Optional[Callable] = None,
        path_to_index_file: Optional[str] = None,
    ):
        """
        Args:
            data_source: S3 URI of the dataset root.
            split: which split to keep, ``"train"`` or ``"val"``.
            transforms: optional callable applied to each decoded PIL image.
            path_to_index_file: optional local path to a cached file index.
        """
        from torchvision.models._meta import _IMAGENET_CATEGORIES

        super().__init__(data_source=data_source, backend="s3", path_to_index_file=path_to_index_file)

        # Keep only the files belonging to the requested split.
        # NOTE(review): this is substring matching — any path merely *containing*
        # "train"/"val" is kept; verify index paths cannot collide.
        self.files = tuple(x for x in self.files if split in x)

        # Human-readable class names; their position defines the class index.
        self.classes = _IMAGENET_CATEGORIES

        self.transforms = transforms

    def load_sample(self, file_path, stream):
        """Decode a single image sample from ``stream``.

        Returns:
            ``(img, cls_idx)`` on success, or ``None`` if the sample could not
            be decoded (best-effort: the error is logged and the sample skipped).
        """
        from PIL import Image

        try:
            img = Image.open(stream)

            if self.transforms is not None:
                img = self.transforms(img)

            # Convert single-channel (grey scale) images to 3-channel RGB.
            # (assumes transforms produced a CxHxW tensor — TODO confirm the
            # no-transforms path, where ``img`` is a PIL image without .shape)
            if img.shape[0] == 1:
                img = img.repeat((3, 1, 1))

            # The parent directory name encodes the class label, with
            # underscores standing in for spaces (e.g. "great_white_shark").
            curr_cls = os.path.basename(os.path.dirname(file_path)).replace("_", " ")
            cls_idx = self.classes.index(curr_cls)
            return img, cls_idx
        except Exception:
            # Fix: the original did ``print(file_path, traceback.print_exc())``,
            # but print_exc() returns None (it writes to stderr itself), so the
            # log line read "path None". Log the path and traceback separately.
            print(f"Failed to load sample {file_path!r}:")
            traceback.print_exc()
            return None
|
|
|
|
|
|
if __name__ == "__main__":
    # AWS credentials are expected in the environment, e.g.:
    # os.environ["AWS_ACCESS_KEY"] = <your aws access key>
    # os.environ["AWS_SECRET_ACCESS_KEY"] = <your aws secret key>

    data_path = "s3://imagenet-tiny"
    index_file_path = "imagenet/imagenet-index.txt"

    # -------------------
    # Step 3: Train
    # -------------------

    print("Instantiate Model")
    model = ImageNetLightningModel(
        weights=None,
        data_path=data_path,
        index_file_path=index_file_path,
        batch_size=args.batchsize,
        workers=args.workers,
    )

    trainer = Trainer()

    print("Train Model")
    # Either evaluate on the validation set (-e/--evaluate) or run a full fit.
    run = trainer.test if args.evaluate else trainer.fit
    run(model)
|