""" This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py """ import argparse import os import random from collections import OrderedDict import torch import torch.backends.cudnn as cudnn import torch.nn.functional as F import torch.nn.parallel import torch.optim as optim import torch.optim.lr_scheduler as lr_scheduler import torch.utils.data import torch.utils.data.distributed import torchvision.datasets as datasets import torchvision.models as models import torchvision.transforms as transforms import pytorch_lightning as pl from pytorch_lightning.core import LightningModule # pull out resnet names from torchvision models MODEL_NAMES = sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) class ImageNetLightningModel(LightningModule): def __init__(self, hparams): """ TODO: add docstring here """ super().__init__() self.hparams = hparams self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained) def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): images, target = batch output = self(images) loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) tqdm_dict = {'train_loss': loss_val} output = OrderedDict({ 'loss': loss_val, 'acc1': acc1, 'acc5': acc5, 'progress_bar': tqdm_dict, 'log': tqdm_dict }) return output def validation_step(self, batch, batch_idx): images, target = batch output = self(images) loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) output = OrderedDict({ 'val_loss': loss_val, 'val_acc1': acc1, 'val_acc5': acc5, }) return output def validation_epoch_end(self, outputs): tqdm_dict = {} for metric_name in ["val_loss", "val_acc1", "val_acc5"]: metric_total = 0 for output in outputs: metric_value = output[metric_name] # reduce manually when using dp if self.trainer.use_dp or self.trainer.use_ddp2: metric_value = torch.mean(metric_value) metric_total += metric_value tqdm_dict[metric_name] = metric_total / len(outputs) result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': tqdm_dict["val_loss"]} return result @classmethod def __accuracy(cls, output, target, topk=(1,)): """Computes the accuracy over the k top predictions for the specified values of k""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res def configure_optimizers(self): optimizer = optim.SGD( self.parameters(), lr=self.hparams.lr, momentum=self.hparams.momentum, weight_decay=self.hparams.weight_decay ) scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1) return [optimizer], [scheduler] def train_dataloader(self): normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) train_dir = os.path.join(self.hparams.data_path, 'train') train_dataset = datasets.ImageFolder( train_dir, transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ])) if self.use_ddp: train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) else: train_sampler = None train_loader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=self.hparams.batch_size, shuffle=(train_sampler is None), num_workers=0, sampler=train_sampler ) return train_loader def val_dataloader(self): normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) val_dir = os.path.join(self.hparams.data_path, 'val') val_loader = torch.utils.data.DataLoader( datasets.ImageFolder(val_dir, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ])), batch_size=self.hparams.batch_size, shuffle=False, num_workers=0, ) return val_loader @staticmethod def add_model_specific_args(parent_parser): # pragma: no-cover parser = argparse.ArgumentParser(parents=[parent_parser]) parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=MODEL_NAMES, help='model architecture: ' + ' | '.join(MODEL_NAMES) + ' (default: resnet18)') parser.add_argument('--epochs', default=90, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--seed', type=int, default=42, help='seed for initializing training. ') parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256), this is the total ' 'batch size of all GPUs on the current node when ' 'using Data Parallel or Distributed Data Parallel') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') return parser def get_args(): parent_parser = argparse.ArgumentParser(add_help=False) parent_parser.add_argument('--data-path', metavar='DIR', type=str, help='path to dataset') parent_parser.add_argument('--save-path', metavar='DIR', default=".", type=str, help='path to save output') parent_parser.add_argument('--gpus', type=int, default=1, help='how many gpus') parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'), help='supports three options dp, ddp, ddp2') parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true', help='if true uses 16 bit precision') parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser = ImageNetLightningModel.add_model_specific_args(parent_parser) return parser.parse_args() def main(hparams): model = ImageNetLightningModel(hparams) if hparams.seed is not None: random.seed(hparams.seed) torch.manual_seed(hparams.seed) cudnn.deterministic = True trainer = pl.Trainer( default_save_path=hparams.save_path, gpus=hparams.gpus, max_epochs=hparams.epochs, distributed_backend=hparams.distributed_backend, precision=16 if hparams.use_16bit else 32, ) if hparams.evaluate: trainer.run_evaluation() else: trainer.fit(model) if __name__ == '__main__': main(get_args())