2019-11-09 06:02:21 +00:00
|
|
|
"""
|
|
|
|
This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
|
|
|
|
"""
|
2020-06-01 15:38:52 +00:00
|
|
|
from argparse import ArgumentParser, Namespace
|
2019-11-09 06:02:21 +00:00
|
|
|
import os
|
|
|
|
import random
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.backends.cudnn as cudnn
|
|
|
|
import torch.nn.functional as F
|
2020-01-20 19:50:31 +00:00
|
|
|
import torch.nn.parallel
|
2019-11-09 06:02:21 +00:00
|
|
|
import torch.optim as optim
|
|
|
|
import torch.optim.lr_scheduler as lr_scheduler
|
|
|
|
import torch.utils.data
|
|
|
|
import torch.utils.data.distributed
|
|
|
|
import torchvision.datasets as datasets
|
2020-01-20 19:50:31 +00:00
|
|
|
import torchvision.models as models
|
|
|
|
import torchvision.transforms as transforms
|
2019-11-09 06:02:21 +00:00
|
|
|
|
|
|
|
import pytorch_lightning as pl
|
2020-02-27 21:07:51 +00:00
|
|
|
from pytorch_lightning.core import LightningModule
|
2019-11-09 06:02:21 +00:00
|
|
|
|
2019-11-09 12:15:07 +00:00
|
|
|
# Every public, lowercase, callable attribute exported by torchvision.models
# (model builder functions such as resnet18), sorted; used as the --arch choices.
# Dunder names and non-callables (e.g. submodules) are filtered out.
MODEL_NAMES = sorted(
    attr_name
    for attr_name, attr_value in vars(models).items()
    if attr_name.islower() and not attr_name.startswith("__") and callable(attr_value)
)
|
|
|
|
|
|
|
|
|
2020-02-27 21:07:51 +00:00
|
|
|
class ImageNetLightningModel(LightningModule):
    """LightningModule that trains a ``torchvision.models`` classifier on ImageFolder data.

    Args:
        arch: name of a model builder in ``torchvision.models`` (see ``MODEL_NAMES``).
        pretrained: forwarded as ``pretrained=`` to the model builder.
        lr: initial SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        data_path: root directory expected to contain ``train`` and ``val`` sub-folders.
        batch_size: batch size used by both the training and validation loaders.
        **kwargs: unused; lets the model be built from the whole CLI namespace
            via ``ImageNetLightningModel(**vars(args))``.
    """

    def __init__(self,
                 arch: str,
                 pretrained: bool,
                 lr: float,
                 momentum: float,
                 weight_decay: float,
                 data_path: str,
                 batch_size: int, **kwargs):
        super().__init__()
        self.arch = arch
        self.pretrained = pretrained
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.data_path = data_path
        self.batch_size = batch_size
        # look the architecture builder up by name and instantiate it
        self.model = models.__dict__[self.arch](pretrained=self.pretrained)

    def forward(self, x):
        """Return the wrapped model's output for input batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy loss and top-1/top-5 accuracy for one training batch.

        Returns:
            OrderedDict with ``loss`` (optimised by Lightning), ``acc1``, ``acc5``,
            and the loss under both ``progress_bar`` and ``log``.
        """
        images, target = batch
        output = self(images)
        loss_val = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))

        tqdm_dict = {'train_loss': loss_val}
        output = OrderedDict({
            'loss': loss_val,
            'acc1': acc1,
            'acc5': acc5,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })

        return output

    def validation_step(self, batch, batch_idx):
        """Compute validation loss and top-1/top-5 accuracy for one batch.

        Returns:
            OrderedDict with ``val_loss``, ``val_acc1`` and ``val_acc5``.
        """
        images, target = batch
        output = self(images)
        loss_val = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))

        output = OrderedDict({
            'val_loss': loss_val,
            'val_acc1': acc1,
            'val_acc5': acc5,
        })

        return output

    def validation_epoch_end(self, outputs):
        """Average each validation metric over all step outputs of the epoch.

        Returns:
            dict with the averages under ``progress_bar`` and ``log``, plus the
            mean ``val_loss`` at top level.
        """
        tqdm_dict = {}

        for metric_name in ["val_loss", "val_acc1", "val_acc5"]:
            metric_total = 0

            for output in outputs:
                metric_value = output[metric_name]

                # reduce manually when using dp
                # (presumably each step value then holds one entry per device -- confirm)
                if self.trainer.use_dp or self.trainer.use_ddp2:
                    metric_value = torch.mean(metric_value)

                metric_total += metric_value

            tqdm_dict[metric_name] = metric_total / len(outputs)

        result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': tqdm_dict["val_loss"]}
        return result

    @classmethod
    def __accuracy(cls, output, target, topk=(1,)):
        """Compute the accuracy (in percent) over the k top predictions for each k in ``topk``.

        Args:
            output: classifier scores; ``topk`` is taken over dim 1
                (assumed (batch, num_classes) -- confirm).
            target: ground-truth class indices, one per row of ``output``.
            topk: tuple of k values to evaluate.

        Returns:
            list with one single-element tensor per k: the percentage of rows whose
            target is among the k highest-scoring classes.
        """
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            # after .t(), row i holds every sample's i-th best prediction
            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                # `correct` derives from a transposed tensor and may be non-contiguous;
                # reshape() copies when needed, whereas view(-1) can raise a RuntimeError.
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    def configure_optimizers(self):
        """Create SGD and a step learning-rate schedule.

        The learning rate is divided by 10 every 30 epochs, as in the reference
        PyTorch ImageNet example this script is adapted from.
        """
        optimizer = optim.SGD(
            self.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay
        )
        # The previous ExponentialLR(gamma=0.1) shrank the LR 10x after *every*
        # epoch, making it negligible within a few epochs; decay in 30-epoch steps.
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30))
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Build the training loader over ``<data_path>/train`` with random crop/flip augmentation."""
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        train_dir = os.path.join(self.data_path, 'train')
        train_dataset = datasets.ImageFolder(
            train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        # NOTE(review): `use_ddp` is not set in this class; presumably the Lightning
        # trainer assigns it on the module -- confirm for the installed version.
        if self.use_ddp:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        else:
            train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            # DataLoader forbids shuffle=True together with an explicit sampler
            shuffle=(train_sampler is None),
            num_workers=0,
            sampler=train_sampler
        )
        return train_loader

    def val_dataloader(self):
        """Build the validation loader over ``<data_path>/val`` (resize 256, center-crop 224, no shuffle)."""
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        val_dir = os.path.join(self.data_path, 'val')
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(val_dir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
        )
        return val_loader

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """Return a new ArgumentParser inheriting ``parent_parser`` plus the model/optimizer options."""
        parser = ArgumentParser(parents=[parent_parser])
        parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=MODEL_NAMES,
                            help='model architecture: ' +
                                 ' | '.join(MODEL_NAMES) +
                                 ' (default: resnet18)')
        parser.add_argument('--epochs', default=90, type=int, metavar='N',
                            help='number of total epochs to run')
        parser.add_argument('--seed', type=int, default=42,
                            help='seed for initializing training. ')
        parser.add_argument('-b', '--batch-size', default=256, type=int,
                            metavar='N',
                            help='mini-batch size (default: 256), this is the total '
                                 'batch size of all GPUs on the current node when '
                                 'using Data Parallel or Distributed Data Parallel')
        parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                            metavar='LR', help='initial learning rate', dest='lr')
        parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                            help='momentum')
        parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                            metavar='W', help='weight decay (default: 1e-4)',
                            dest='weight_decay')
        parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                            help='use pre-trained model')
        return parser
|
|
|
|
|
|
|
|
|
|
|
|
def get_args():
    """Parse the command line for this script.

    Script-level options (paths, GPU count, distributed backend, precision,
    evaluate-only flag) are declared on a help-less parent parser.
    ``ImageNetLightningModel.add_model_specific_args`` then builds the final
    parser from it and adds the model/optimizer options.

    Returns:
        argparse.Namespace holding every parsed option.
    """
    parent_parser = ArgumentParser(add_help=False)
    # Required: the dataloaders join this path with 'train'/'val'. Without it,
    # the failure would only surface later as a TypeError from os.path.join(None, ...).
    parent_parser.add_argument('--data-path', metavar='DIR', type=str, required=True,
                               help='path to dataset')
    parent_parser.add_argument('--save-path', metavar='DIR', default=".", type=str,
                               help='path to save output')
    parent_parser.add_argument('--gpus', type=int, default=1,
                               help='how many gpus')
    parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
                               help='supports three options dp, ddp, ddp2')
    parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true',
                               help='if true uses 16 bit precision')
    parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                               help='evaluate model on validation set')

    parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
    return parser.parse_args()
|
|
|
|
|
|
|
|
|
2020-06-01 15:38:52 +00:00
|
|
|
def main(args: Namespace) -> None:
    """Build the model and a Lightning Trainer from parsed CLI ``args``, then train or evaluate.

    When ``args.seed`` is set, Python's and torch's RNGs are seeded and cuDNN is
    switched to deterministic mode before the trainer is created.
    """
    model = ImageNetLightningModel(**vars(args))

    seed = args.seed
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True

    precision = 16 if args.use_16bit else 32
    trainer = pl.Trainer(
        default_root_dir=args.save_path,
        gpus=args.gpus,
        max_epochs=args.epochs,
        distributed_backend=args.distributed_backend,
        precision=precision,
    )

    if not args.evaluate:
        trainer.fit(model)
        return

    # NOTE(review): on this path `model` is never handed to the trainer (no fit/test
    # call precedes run_evaluation); confirm the installed Lightning version supports it.
    trainer.run_evaluation()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # script entry point: parse the CLI, then train or evaluate
    cli_args = get_args()
    main(cli_args)
|