lightning/pl_examples/domain_templates/imagenet.py

251 lines
9.0 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py.
Before you can run this example, you will need to download the ImageNet dataset manually from the
`official website <http://image-net.org/download>`_ and place it into a folder `path/to/imagenet`.
Train on ImageNet with default parameters:
.. code-block: bash
python imagenet.py --data-path /path/to/imagenet
or show all options you can change:
.. code-block: bash
python imagenet.py --help
"""
import os
from argparse import ArgumentParser, Namespace
import torch
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning.core import LightningModule
class ImageNetLightningModel(LightningModule):
"""
>>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
ImageNetLightningModel(
(model): ResNet(...)
)
"""
# pull out resnet names from torchvision models
MODEL_NAMES = sorted(
name
for name in models.__dict__
if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)
def __init__(
self,
data_path: str,
arch: str = "resnet18",
pretrained: bool = False,
lr: float = 0.1,
momentum: float = 0.9,
weight_decay: float = 1e-4,
batch_size: int = 4,
workers: int = 2,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
self.arch = arch
self.pretrained = pretrained
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
self.data_path = data_path
self.batch_size = batch_size
self.workers = workers
self.model = models.__dict__[self.arch](pretrained=self.pretrained)
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
images, target = batch
output = self(images)
loss_train = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
self.log("train_loss", loss_train, on_step=True, on_epoch=True, logger=True)
self.log("train_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True)
self.log("train_acc5", acc5, on_step=True, on_epoch=True, logger=True)
return loss_train
def eval_step(self, batch, batch_idx, prefix: str):
images, target = batch
output = self(images)
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
self.log(f"{prefix}_loss", loss_val, on_step=True, on_epoch=True)
self.log(f"{prefix}_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True)
self.log(f"{prefix}_acc5", acc5, on_step=True, on_epoch=True)
def validation_step(self, batch, batch_idx):
return self.eval_step(batch, batch_idx, "val")
@staticmethod
def __accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k."""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def configure_optimizers(self):
optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30))
return [optimizer], [scheduler]
def train_dataloader(self):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_dir = os.path.join(self.data_path, "train")
train_dataset = datasets.ImageFolder(
train_dir,
transforms.Compose(
[transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize]
),
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers
)
return train_loader
def val_dataloader(self):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
val_dir = os.path.join(self.data_path, "val")
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(
val_dir,
transforms.Compose(
[transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
),
),
batch_size=self.batch_size,
shuffle=False,
num_workers=self.workers,
)
return val_loader
def test_dataloader(self):
return self.val_dataloader()
def test_step(self, batch, batch_idx):
return self.eval_step(batch, batch_idx, "test")
@staticmethod
def add_model_specific_args(parent_parser): # pragma: no-cover
parser = parent_parser.add_argument_group("ImageNetLightningModel")
parser.add_argument(
"-a",
"--arch",
metavar="ARCH",
default="resnet18",
choices=ImageNetLightningModel.MODEL_NAMES,
help=("model architecture: " + " | ".join(ImageNetLightningModel.MODEL_NAMES) + " (default: resnet18)"),
)
parser.add_argument(
"-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
)
parser.add_argument(
"-b",
"--batch-size",
default=256,
type=int,
metavar="N",
help="mini-batch size (default: 256), this is the total batch size of all GPUs on the current node"
" when using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
"--lr", "--learning-rate", default=0.1, type=float, metavar="LR", help="initial learning rate", dest="lr"
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
"--wd",
"--weight-decay",
default=1e-4,
type=float,
metavar="W",
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument("--pretrained", dest="pretrained", action="store_true", help="use pre-trained model")
return parent_parser
def main(args: Namespace) -> None:
if args.seed is not None:
pl.seed_everything(args.seed)
if args.accelerator == "ddp":
# When using a single GPU per process and per
# DistributedDataParallel, we need to divide the batch size
# ourselves based on the total number of GPUs we have
args.batch_size = int(args.batch_size / max(1, args.gpus))
args.workers = int(args.workers / max(1, args.gpus))
model = ImageNetLightningModel(**vars(args))
trainer = pl.Trainer.from_argparse_args(args)
if args.evaluate:
trainer.test(model)
else:
trainer.fit(model)
def run_cli():
parent_parser = ArgumentParser(add_help=False)
parent_parser = pl.Trainer.add_argparse_args(parent_parser)
parent_parser.add_argument("--data-path", metavar="DIR", type=str, help="path to dataset")
parent_parser.add_argument(
"-e", "--evaluate", dest="evaluate", action="store_true", help="evaluate model on validation set"
)
parent_parser.add_argument("--seed", type=int, default=42, help="seed for initializing training.")
parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
parser.set_defaults(profiler="simple", deterministic=True, max_epochs=90)
args = parser.parse_args()
main(args)
if __name__ == "__main__":
cli_lightning_logo()
run_cli()