# lightning/pl_examples/domain_templates/computer_vision_fine_tuning.py
#
# 455 lines
# 17 KiB
# Python

"""Computer vision example on Transfer Learning.
This computer vision example illustrates how one could fine-tune a pre-trained
network (by default, a ResNet50 is used) using pytorch-lightning. For the sake
of this example, the proposed network (denoted by `TransferLearningModel`, see
below) is trained for 15 epochs on the 'cats and dogs dataset' (~60MB, see
`DATA_URL` below). The training consists of three stages. From epoch 0 to 4,
the feature extractor (the pre-trained network) is frozen, except possibly
for the BatchNorm layers (depending on whether `train_bn = True`). The
BatchNorm layers (if `train_bn = True`) and the parameters of the classifier
are trained as a single parameter group with lr = 1e-2. From epoch 5 to 9, the
last two layer groups of the pre-trained network are unfrozen and added to the
optimizer as a new parameter group with lr = 1e-4 (while lr = 1e-3 for the
first parameter group in the optimizer). Finally, from epoch 10, all the
remaining layer groups of the pre-trained network are unfrozen and added to
the optimizer as a third parameter group. From epoch 10 onwards, the
parameters of the pre-trained network are trained with lr = 1e-5 while those
of the classifier are trained with lr = 1e-4.
Note:
See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
"""
import argparse
from collections import OrderedDict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Generator, Union
from torch.nn import Module
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning import _logger as log
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive
# BatchNorm layer classes; the freezing helpers below treat them specially
# depending on the `train_bn` flag.
BN_TYPES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
)
# Zip archive of the filtered "cats and dogs" image dataset.
DATA_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
# --- Utility functions ---
def _make_trainable(module: Module) -> None:
"""Unfreezes a given module.
Args:
module: The module to unfreeze
"""
for param in module.parameters():
param.requires_grad = True
module.train()
def _recursive_freeze(module: Module,
                      train_bn: bool = True) -> None:
    """Recursively freeze every leaf layer below ``module``.

    A leaf layer (a module without children) has all its parameters set to
    ``requires_grad = False`` and is switched to eval mode. The exception is a
    BatchNorm leaf when ``train_bn`` is True: it is made trainable instead.

    Args:
        module: The module to freeze.
        train_bn: If True, leave the BatchNorm layers trainable and in
            training mode.
    """
    sub_modules = list(module.children())
    if sub_modules:
        # Container module: descend into each child.
        for sub in sub_modules:
            _recursive_freeze(module=sub, train_bn=train_bn)
        return

    if train_bn and isinstance(module, BN_TYPES):
        # Make the BN layers trainable
        _make_trainable(module)
        return

    for weight in module.parameters():
        weight.requires_grad = False
    module.eval()
def freeze(module: Module,
           n: Optional[int] = None,
           train_bn: bool = True) -> None:
    """Freeze the first direct children of ``module`` and unfreeze the rest.

    Args:
        module: The module to freeze (at least partially).
        n: Slice index into ``list(module.children())``. Children before this
            index are frozen and children from it onwards are made trainable.
            Negative values count from the end (e.g. ``-2`` keeps the last two
            children trainable). If None, every child is frozen.
        train_bn: If True, leave the BatchNorm layers of the frozen children
            trainable and in training mode.
    """
    layers = list(module.children())
    cut = len(layers) if n is None else int(n)
    frozen_part, trainable_part = layers[:cut], layers[cut:]
    for layer in frozen_part:
        _recursive_freeze(module=layer, train_bn=train_bn)
    for layer in trainable_part:
        _make_trainable(module=layer)
def filter_params(module: Module,
                  train_bn: bool = True) -> Generator:
    """Yield the parameters of ``module`` that currently require gradients.

    Only leaf modules (modules without children) contribute parameters. When
    ``train_bn`` is True, BatchNorm leaves are skipped entirely. In this
    example such layers are kept trainable by ``freeze`` from the start, so
    their parameters are presumably already registered with the optimizer.

    Args:
        module: A given module.
        train_bn: If True, do not yield the parameters of BatchNorm layers.

    Returns:
        Generator over the selected parameters.
    """
    sub_modules = list(module.children())
    if sub_modules:
        for sub in sub_modules:
            yield from filter_params(module=sub, train_bn=train_bn)
        return
    if train_bn and isinstance(module, BN_TYPES):
        return
    yield from (p for p in module.parameters() if p.requires_grad)
def _unfreeze_and_add_param_group(module: Module,
                                  optimizer: Optimizer,
                                  lr: Optional[float] = None,
                                  train_bn: bool = True):
    """Unfreeze ``module`` and register its parameters with ``optimizer``.

    The parameters selected by ``filter_params`` are added as a new parameter
    group whose learning rate is one tenth of a reference rate.

    Args:
        module: Module to unfreeze.
        optimizer: Optimizer that receives the new parameter group.
        lr: Reference learning rate; the new group uses ``lr / 10``. Defaults
            to the current ``lr`` of the optimizer's first parameter group.
        train_bn: Forwarded to ``filter_params``; when True, BatchNorm
            parameters are left out of the new group.
    """
    _make_trainable(module)
    if lr is None:
        base_lr = optimizer.param_groups[0]['lr']
    else:
        base_lr = float(lr)
    new_group = {
        'params': filter_params(module=module, train_bn=train_bn),
        'lr': base_lr / 10.,
    }
    optimizer.add_param_group(new_group)
# --- Pytorch-lightning module ---
class TransferLearningModel(pl.LightningModule):
    """Transfer Learning with a pre-trained ResNet (ResNet50 by default).

    Args:
        dl_path: Path where the data will be downloaded.
        backbone: Name (as in ``torchvision.models``) of the feature extractor.
        train_bn: If True, the BatchNorm layers of the frozen feature extractor
            are kept trainable.
        milestones: Epochs at which the learning rate is decayed. The first
            two also mark when the last two, and then all the remaining, layer
            groups of the feature extractor are unfrozen.
        batch_size: Batch size of the training and validation loaders.
        lr: Initial learning rate of the optimizer.
        lr_scheduler_gamma: Factor by which the learning rate is multiplied at
            each milestone.
        num_workers: Number of worker processes of each DataLoader.
        **kwargs: Extra options (e.g. the whole CLI namespace); ignored.

    Raises:
        ValueError: If ``milestones`` contains fewer than two epochs.
    """

    def __init__(self,
                 dl_path: Union[str, Path],
                 backbone: str = 'resnet50',
                 train_bn: bool = True,
                 milestones: tuple = (5, 10),
                 batch_size: int = 8,
                 lr: float = 1e-2,
                 lr_scheduler_gamma: float = 1e-1,
                 num_workers: int = 6, **kwargs) -> None:
        super().__init__()
        # `train` and `on_epoch_start` index the first two milestones; fail
        # early with a clear message instead of an IndexError mid-training.
        if len(milestones) < 2:
            raise ValueError('`milestones` must contain at least two epochs, '
                             'got {!r}'.format(milestones))
        self.dl_path = dl_path
        self.backbone = backbone
        self.train_bn = train_bn
        self.milestones = milestones
        self.batch_size = batch_size
        self.lr = lr
        self.lr_scheduler_gamma = lr_scheduler_gamma
        self.num_workers = num_workers
        self.__build_model()

    def __build_model(self):
        """Define model layers & loss."""
        # 1. Load pre-trained network (without its final classification layer):
        model_func = getattr(models, self.backbone)
        backbone = model_func(pretrained=True)
        # Width of the features entering the original classification head.
        # torchvision ResNets expose it as `fc.in_features` (512 for
        # resnet18/34, 2048 for resnet50+); fall back to the ResNet50 value
        # for backbones without such a head.
        head = getattr(backbone, 'fc', None)
        num_features = getattr(head, 'in_features', 2048)
        _layers = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*_layers)
        freeze(module=self.feature_extractor, train_bn=self.train_bn)
        # 2. Classifier:
        # NOTE(review): there is no non-linearity between these Linear layers,
        # so together they compute a single affine map. Kept as-is so the
        # parameter layout (and existing checkpoints) stays unchanged.
        _fc_layers = [torch.nn.Linear(num_features, 256),
                      torch.nn.Linear(256, 32),
                      torch.nn.Linear(32, 1)]
        self.fc = torch.nn.Sequential(*_fc_layers)
        # 3. Loss:
        self.loss_func = F.binary_cross_entropy_with_logits

    def forward(self, x):
        """Forward pass. Returns one logit per sample, shape (batch, 1)."""
        # 1. Feature extraction; the trailing two singleton dims are dropped
        # (assumes the backbone ends with a global pooling layer, as ResNets do).
        x = self.feature_extractor(x)
        x = x.squeeze(-1).squeeze(-1)
        # 2. Classifier (returns logits):
        x = self.fc(x)
        return x

    def loss(self, logits, labels):
        """Binary cross-entropy computed on raw logits."""
        return self.loss_func(input=logits, target=labels)

    def train(self, mode=True):
        """Set the training mode and re-apply the freezing schedule.

        ``super().train(mode)`` switches every sub-module (frozen ones
        included) to training mode, so the frozen part of the feature
        extractor is re-frozen according to the current epoch.

        Returns:
            ``self``, as ``torch.nn.Module.train`` does.
        """
        super().train(mode=mode)
        epoch = self.current_epoch
        if mode:
            if epoch < self.milestones[0]:
                # Feature extractor is frozen (except for the BatchNorm
                # layers when `train_bn` is True).
                freeze(module=self.feature_extractor,
                       train_bn=self.train_bn)
            elif epoch < self.milestones[1]:
                # Unfreeze last two layers of the feature extractor.
                freeze(module=self.feature_extractor,
                       n=-2,
                       train_bn=self.train_bn)
        return self

    def on_epoch_start(self):
        """Use `on_epoch_start` to unfreeze layers progressively."""
        optimizer = self.trainer.optimizers[0]
        if self.current_epoch == self.milestones[0]:
            _unfreeze_and_add_param_group(module=self.feature_extractor[-2:],
                                          optimizer=optimizer,
                                          train_bn=self.train_bn)
        elif self.current_epoch == self.milestones[1]:
            _unfreeze_and_add_param_group(module=self.feature_extractor[:-2],
                                          optimizer=optimizer,
                                          train_bn=self.train_bn)

    def _loss_and_counts(self, batch):
        """Forward pass shared by the training and validation steps.

        Returns:
            A tuple ``(loss, num_correct, num_samples)``; the two counts are
            0-dim integer tensors on the device of the batch.
        """
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_true = y.view((-1, 1)).type_as(x)
        # A logit >= 0 is a predicted probability >= 0.5 for the positive
        # class. Cast to the label dtype so both operands of `eq` match.
        y_bin = torch.ge(y_logits, 0).type_as(y_true)
        # 2. Compute loss & accuracy counts:
        loss = self.loss(y_logits, y_true)
        num_correct = torch.eq(y_bin.view(-1), y_true.view(-1)).sum()
        # The true batch size: the last batch of an epoch may be smaller
        # than `self.batch_size`.
        num_samples = torch.tensor(y_true.size(0), device=y_true.device)
        return loss, num_correct, num_samples

    @staticmethod
    def _epoch_accuracy(outputs):
        """Accuracy over all step outputs of an epoch, weighted by batch size."""
        num_correct = torch.stack([output['num_correct']
                                   for output in outputs]).sum().float()
        num_samples = torch.stack([output['num_samples']
                                   for output in outputs]).sum().float()
        return num_correct / num_samples

    def training_step(self, batch, batch_idx):
        """Compute the training loss and per-batch accuracy counts."""
        train_loss, num_correct, num_samples = self._loss_and_counts(batch)
        tqdm_dict = {'train_loss': train_loss}
        output = OrderedDict({'loss': train_loss,
                              'num_correct': num_correct,
                              'num_samples': num_samples,
                              'log': tqdm_dict,
                              'progress_bar': tqdm_dict})
        return output

    def training_epoch_end(self, outputs):
        """Compute and log training loss and accuracy at the epoch level."""
        train_loss_mean = torch.stack([output['loss']
                                       for output in outputs]).mean()
        train_acc_mean = self._epoch_accuracy(outputs)
        return {'log': {'train_loss': train_loss_mean,
                        'train_acc': train_acc_mean,
                        'step': self.current_epoch}}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and per-batch accuracy counts."""
        val_loss, num_correct, num_samples = self._loss_and_counts(batch)
        return {'val_loss': val_loss,
                'num_correct': num_correct,
                'num_samples': num_samples}

    def validation_epoch_end(self, outputs):
        """Compute and log validation loss and accuracy at the epoch level."""
        val_loss_mean = torch.stack([output['val_loss']
                                     for output in outputs]).mean()
        val_acc_mean = self._epoch_accuracy(outputs)
        return {'log': {'val_loss': val_loss_mean,
                        'val_acc': val_acc_mean,
                        'step': self.current_epoch}}

    def configure_optimizers(self):
        """Adam over the currently trainable parameters + MultiStepLR decay."""
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      self.parameters()),
                               lr=self.lr)
        scheduler = MultiStepLR(optimizer,
                                milestones=self.milestones,
                                gamma=self.lr_scheduler_gamma)
        return [optimizer], [scheduler]

    def prepare_data(self):
        """Download images and prepare images datasets."""
        download_and_extract_archive(url=DATA_URL,
                                     download_root=self.dl_path,
                                     remove_finished=True)

    def setup(self, stage: str):
        """Build the train/validation ``ImageFolder`` datasets."""
        data_path = Path(self.dl_path).joinpath('cats_and_dogs_filtered')
        # Load the data + preprocessing & data augmentation
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = ImageFolder(root=data_path.joinpath('train'),
                                    transform=transforms.Compose([
                                        transforms.Resize((224, 224)),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))
        valid_dataset = ImageFolder(root=data_path.joinpath('validation'),
                                    transform=transforms.Compose([
                                        transforms.Resize((224, 224)),
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset

    def __dataloader(self, train):
        """Train/validation loaders (only the training loader is shuffled)."""
        _dataset = self.train_dataset if train else self.valid_dataset
        loader = DataLoader(dataset=_dataset,
                            batch_size=self.batch_size,
                            num_workers=self.num_workers,
                            shuffle=train)
        return loader

    def train_dataloader(self):
        log.info('Training data loaded.')
        return self.__dataloader(train=True)

    def val_dataloader(self):
        log.info('Validation data loaded.')
        return self.__dataloader(train=False)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the model options."""

        def _str2bool(value):
            # `type=bool` would map any non-empty string (including 'False')
            # to True, so parse the usual spellings explicitly.
            if isinstance(value, bool):
                return value
            lowered = value.strip().lower()
            if lowered in ('yes', 'true', 't', 'y', '1'):
                return True
            if lowered in ('no', 'false', 'f', 'n', '0'):
                return False
            raise argparse.ArgumentTypeError(
                'Boolean value expected, got {!r}'.format(value))

        parser = argparse.ArgumentParser(parents=[parent_parser])
        parser.add_argument('--backbone',
                            default='resnet50',
                            type=str,
                            metavar='BK',
                            help='Name (as in ``torchvision.models``) of the feature extractor')
        parser.add_argument('--epochs',
                            default=15,
                            type=int,
                            metavar='N',
                            help='total number of epochs',
                            dest='nb_epochs')
        parser.add_argument('--batch-size',
                            default=8,
                            type=int,
                            metavar='B',
                            help='batch size',
                            dest='batch_size')
        parser.add_argument('--gpus',
                            type=int,
                            default=1,
                            help='number of gpus to use')
        parser.add_argument('--lr',
                            '--learning-rate',
                            default=1e-2,
                            type=float,
                            metavar='LR',
                            help='initial learning rate',
                            dest='lr')
        parser.add_argument('--lr-scheduler-gamma',
                            default=1e-1,
                            type=float,
                            metavar='LRG',
                            help='Factor by which the learning rate is reduced at each milestone',
                            dest='lr_scheduler_gamma')
        parser.add_argument('--num-workers',
                            default=6,
                            type=int,
                            metavar='W',
                            help='number of CPU workers',
                            dest='num_workers')
        parser.add_argument('--train-bn',
                            default=True,
                            type=_str2bool,
                            metavar='TB',
                            help='Whether the BatchNorm layers should be trainable',
                            dest='train_bn')
        # Two integers, e.g. `--milestones 5 10` (`type=list` would split a
        # single string into characters).
        parser.add_argument('--milestones',
                            default=[5, 10],
                            type=int,
                            nargs=2,
                            metavar='M',
                            help='List of two epochs milestones')
        return parser
def main(args: argparse.Namespace) -> None:
    """Fine-tune ``TransferLearningModel`` with the parsed command-line options.

    Args:
        args: Model hyper-parameters and trainer options.

    Note:
        For the sake of the example, the images dataset is downloaded into a
        temporary directory created under ``args.root_data_path``; the
        directory is removed when the ``with`` block exits.
    """
    with TemporaryDirectory(dir=args.root_data_path) as tmp_dir:
        hyper_params = vars(args)
        model = TransferLearningModel(dl_path=tmp_dir, **hyper_params)
        trainer_options = dict(
            weights_summary=None,
            show_progress_bar=True,
            num_sanity_val_steps=0,
            gpus=args.gpus,
            min_epochs=args.nb_epochs,
            max_epochs=args.nb_epochs,
        )
        trainer = pl.Trainer(**trainer_options)
        trainer.fit(model)
def get_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        Namespace holding ``root_data_path`` plus every option registered by
        ``TransferLearningModel.add_model_specific_args``.
    """
    base = argparse.ArgumentParser(add_help=False)
    base.add_argument(
        '--root-data-path',
        metavar='DIR',
        type=str,
        default=Path.cwd().as_posix(),
        help='Root directory where to download the data',
        dest='root_data_path',
    )
    full_parser = TransferLearningModel.add_model_specific_args(base)
    return full_parser.parse_args()
if __name__ == '__main__':
    # Script entry point: parse the command line, then train.
    cli_args = get_args()
    main(cli_args)