From d56750899fd9526ee2fa765ed79e3250f2ae2ecc Mon Sep 17 00:00:00 2001
From: Jason Phang
Date: Sat, 9 Nov 2019 01:02:21 -0500
Subject: [PATCH] ImageNet Example (#476)

* ImageNet example

* cleanup

* cleanup

* Minor changes from feedback

* More cleanup
---
 .../imagenet/imagenet_example.py | 243 ++++++++++++++++++
 1 file changed, 243 insertions(+)
 create mode 100644 pl_examples/full_examples/imagenet/imagenet_example.py

diff --git a/pl_examples/full_examples/imagenet/imagenet_example.py b/pl_examples/full_examples/imagenet/imagenet_example.py
new file mode 100644
index 0000000000..3e45abd073
--- /dev/null
+++ b/pl_examples/full_examples/imagenet/imagenet_example.py
@@ -0,0 +1,243 @@
+"""
+This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
+"""
+import argparse
+import os
+import random
+from collections import OrderedDict
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn.parallel
+import torch.nn.functional as F
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+import torch.utils.data
+import torch.utils.data.distributed
+
+import torchvision.transforms as transforms
+import torchvision.models as models
+import torchvision.datasets as datasets
+
+import pytorch_lightning as pl
+
+
+MODEL_NAMES = sorted(
+    name for name in models.__dict__
+    if name.islower() and not name.startswith("__")
+    and callable(models.__dict__[name])
+)
+NORMALIZE = transforms.Normalize(
+    mean=[0.485, 0.456, 0.406],
+    std=[0.229, 0.224, 0.225],
+)
+
+
+class ImageNetLightningModel(pl.LightningModule):
+
+    def __init__(self, hparams):
+        super(ImageNetLightningModel, self).__init__()
+        self.hparams = hparams
+        self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)
+
+    def training_step(self, batch, batch_idx):
+        images, target = batch
+        output = self.model(images)
+        loss_val = F.cross_entropy(output, target)
+        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
+
+        # in DP mode (default) make sure that scalar results get an extra
+        # leading dim so they can be gathered across GPUs
+        if self.trainer.use_dp or self.trainer.use_ddp2:
+            loss_val = loss_val.unsqueeze(0)
+            acc1 = acc1.unsqueeze(0)
+            acc5 = acc5.unsqueeze(0)
+
+        tqdm_dict = {'train_loss': loss_val}
+        output = OrderedDict({
+            'loss': loss_val,
+            'acc1': acc1,
+            'acc5': acc5,
+            'progress_bar': tqdm_dict,
+            'log': tqdm_dict
+        })
+
+        return output
+
+    def validation_step(self, batch, batch_idx):
+        images, target = batch
+        output = self.model(images)
+        loss_val = F.cross_entropy(output, target)
+        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
+
+        # in DP mode (default) make sure that scalar results get an extra
+        # leading dim so they can be gathered across GPUs
+        if self.trainer.use_dp or self.trainer.use_ddp2:
+            loss_val = loss_val.unsqueeze(0)
+            acc1 = acc1.unsqueeze(0)
+            acc5 = acc5.unsqueeze(0)
+
+        output = OrderedDict({
+            'val_loss': loss_val,
+            'val_acc1': acc1,
+            'val_acc5': acc5,
+        })
+
+        return output
+
+    def validation_end(self, outputs):
+        tqdm_dict = {}
+
+        for metric_name in ["val_loss", "val_acc1", "val_acc5"]:
+            metric_total = 0
+
+            for output in outputs:
+                metric_value = output[metric_name]
+
+                # reduce manually when using dp
+                if self.trainer.use_dp or self.trainer.use_ddp2:
+                    metric_value = torch.mean(metric_value)
+
+                metric_total += metric_value
+
+            tqdm_dict[metric_name] = metric_total / len(outputs)
+
+        result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': tqdm_dict["val_loss"]}
+        return result
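+
+    # Illustrative sketch of what ``__accuracy`` below returns (values made
+    # up for the example, not part of the training logic): for a batch of 2
+    # samples where the true class is ranked 1st for one sample and 3rd for
+    # the other, topk=(1, 5) gives acc1 == 50.0 and acc5 == 100.0, i.e.
+    # percentages over the batch.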
+
+    @classmethod
+    def __accuracy(cls, output, target, topk=(1,)):
+        """Computes the accuracy over the k top predictions for the specified values of k"""
+        with torch.no_grad():
+            maxk = max(topk)
+            batch_size = target.size(0)
+
+            _, pred = output.topk(maxk, 1, True, True)
+            pred = pred.t()
+            correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+            res = []
+            for k in topk:
+                correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+                res.append(correct_k.mul_(100.0 / batch_size))
+            return res
+
+    def configure_optimizers(self):
+        optimizer = optim.SGD(
+            self.parameters(),
+            lr=self.hparams.lr,
+            momentum=self.hparams.momentum,
+            weight_decay=self.hparams.weight_decay
+        )
+        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
+        return [optimizer], [scheduler]
+
+    @pl.data_loader
+    def train_dataloader(self):
+        train_dir = os.path.join(self.hparams.data_path, 'train')
+        train_dataset = datasets.ImageFolder(
+            train_dir,
+            transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                NORMALIZE,
+            ]))
+
+        if self.use_ddp:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        else:
+            train_sampler = None
+
+        train_loader = torch.utils.data.DataLoader(
+            dataset=train_dataset,
+            batch_size=self.hparams.batch_size,
+            shuffle=(train_sampler is None),
+            num_workers=0,
+            sampler=train_sampler
+        )
+        return train_loader
+
+    @pl.data_loader
+    def val_dataloader(self):
+        val_dir = os.path.join(self.hparams.data_path, 'val')
+        val_loader = torch.utils.data.DataLoader(
+            datasets.ImageFolder(val_dir, transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                NORMALIZE,
+            ])),
+            batch_size=self.hparams.batch_size,
+            shuffle=False,
+            num_workers=0,
+        )
+        return val_loader
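+
+    # Note: the loaders above assume the standard torchvision ImageFolder
+    # layout, one subdirectory per class (paths illustrative):
+    #   <data-path>/train/n01440764/xxx.JPEG
+    #   <data-path>/val/n01440764/yyy.JPEG
+    # Under ddp the DistributedSampler already shards the training set per
+    # process, which is why shuffle is only enabled when no sampler is set.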
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):  # pragma: no cover
+        parser = argparse.ArgumentParser(parents=[parent_parser])
+        parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=MODEL_NAMES,
+                            help='model architecture: ' +
+                                 ' | '.join(MODEL_NAMES) +
+                                 ' (default: resnet18)')
+        parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                            help='number of total epochs to run')
+        parser.add_argument('--seed', type=int, default=None,
+                            help='seed for initializing training')
+        parser.add_argument('-b', '--batch-size', default=256, type=int,
+                            metavar='N',
+                            help='mini-batch size (default: 256), this is the total '
+                                 'batch size of all GPUs on the current node when '
+                                 'using Data Parallel or Distributed Data Parallel')
+        parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                            metavar='LR', help='initial learning rate', dest='lr')
+        parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                            help='momentum')
+        parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                            metavar='W', help='weight decay (default: 1e-4)',
+                            dest='weight_decay')
+        parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                            help='use pre-trained model')
+        return parser
+
+
+def get_args():
+    parent_parser = argparse.ArgumentParser(add_help=False)
+    parent_parser.add_argument('--data-path', metavar='DIR', type=str,
+                               help='path to dataset')
+    parent_parser.add_argument('--save-path', metavar='DIR', default=".", type=str,
+                               help='path to save output')
+    parent_parser.add_argument('--gpus', type=int, default=1,
+                               help='how many gpus')
+    parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
+                               help='supports three options: dp, ddp, ddp2')
+    parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true',
+                               help='if true uses 16 bit precision')
+    parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                               help='evaluate model on validation set')
+
+    parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
+    return parser.parse_args()
+
+
+def main(hparams):
+    # seed before building the model so weight initialization is reproducible
+    if hparams.seed is not None:
+        random.seed(hparams.seed)
+        torch.manual_seed(hparams.seed)
+        cudnn.deterministic = True
+    model = ImageNetLightningModel(hparams)
+    trainer = pl.Trainer(
+        default_save_path=hparams.save_path,
+        gpus=hparams.gpus,
+        max_nb_epochs=hparams.epochs,
+        distributed_backend=hparams.distributed_backend,
+        use_amp=hparams.use_16bit
+    )
+    if hparams.evaluate:
+        trainer.run_evaluation()
+    else:
+        trainer.fit(model)
+
+
+if __name__ == '__main__':
+    main(get_args())
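+
+# Example invocations (a sketch; the dataset path is a placeholder):
+#   single GPU with the default resnet18:
+#     python imagenet_example.py --data-path /path/to/imagenet
+#   two GPUs with ddp, resnet50 and 16-bit precision:
+#     python imagenet_example.py --data-path /path/to/imagenet --gpus 2 \
+#         --distributed-backend ddp -a resnet50 --use-16bit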