From fafe5d63a70a422fb8c8892c6f0a10c2c6f23816 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste SCHIRATTI Date: Sat, 2 May 2020 15:08:46 +0200 Subject: [PATCH] Transfer learning example (#1564) * Fine tuning example. * Fix (in train method) + Borda's comments (added argparse + fixed docstrings). * Updated CHANGELOG.md * Fix + updated docstring. * Fixes (awaelchli's comments) + docstrings. * Fix train/val loss. * Fix. --- CHANGELOG.md | 2 + .../computer_vision_fine_tuning.py | 440 ++++++++++++++++++ 2 files changed, 442 insertions(+) create mode 100644 pl_examples/domain_templates/computer_vision_fine_tuning.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f67ba54900..ef025c6a85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498)) +- Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564)) + ### Changed ### Deprecated diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py new file mode 100644 index 0000000000..42a0a936d9 --- /dev/null +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -0,0 +1,440 @@ +"""Computer vision example on Transfer Learning. + +This computer vision example illustrates how one could fine-tune a pre-trained +network (by default, a ResNet50 is used) using pytorch-lightning. For the sake +of this example, the 'cats and dogs dataset' (~60MB, see `DATA_URL` below) and +the proposed network (denoted by `TransferLearningModel`, see below) is +trained for 15 epochs. The training consists in three stages. From epoch 0 to +4, the feature extractor (the pre-trained network) is frozen except maybe for +the BatchNorm layers (depending on whether `train_bn = True`). The BatchNorm +layers (if `train_bn = True`) and the parameters of the classifier are trained +as a single parameters group with lr = 1e-2. From epoch 5 to 9, the last two +layer groups of the pre-trained network are unfrozen and added to the +optimizer as a new parameter group with lr = 1e-4 (while lr = 1e-3 for the +first parameter group in the optimizer). Eventually, from epoch 10, all the +remaining layer groups of the pre-trained network are unfrozen and added to +the optimizer as a third parameter group. From epoch 10, the parameters of the +pre-trained network are trained with lr = 1e-5 while those of the classifier +are trained with lr = 1e-4. + +Note: + See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html +""" + +import argparse +from collections import OrderedDict +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Generator, Union + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from pytorch_lightning import _logger as log +from torch import optim +from torch.optim.lr_scheduler import MultiStepLR +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader +from torchvision import models +from torchvision import transforms +from torchvision.datasets import ImageFolder +from torchvision.datasets.utils import download_and_extract_archive + +BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d) +DATA_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip' + + +# --- Utility functions --- + + +def _make_trainable(module: torch.nn.Module) -> None: + """Unfreezes a given module. + + Args: + module: The module to unfreeze + """ + for param in module.parameters(): + param.requires_grad = True + module.train() + + +def _recursive_freeze(module: torch.nn.Module, + train_bn: bool = True) -> None: + """Freezes the layers of a given module. + + Args: + module: The module to freeze + train_bn: If True, leave the BatchNorm layers in training mode + """ + children = list(module.children()) + if not children: + if not (isinstance(module, BN_TYPES) and train_bn): + for param in module.parameters(): + param.requires_grad = False + module.eval() + else: + # Make the BN layers trainable + _make_trainable(module) + else: + for child in children: + _recursive_freeze(module=child, train_bn=train_bn) + + +def freeze(module: torch.nn.Module, + n: Optional[int] = None, + train_bn: bool = True) -> None: + """Freezes the layers up to index n (if n is not None). + + Args: + module: The module to freeze (at least partially) + n: Max depth at which we stop freezing the layers. If None, all + the layers of the given module will be frozen. + train_bn: If True, leave the BatchNorm layers in training mode + """ + children = list(module.children()) + n_max = len(children) if n is None else int(n) + + for child in children[:n_max]: + _recursive_freeze(module=child, train_bn=train_bn) + + for child in children[n_max:]: + _make_trainable(module=child) + + +def filter_params(module: torch.nn.Module, + train_bn: bool = True) -> Generator: + """Yields the trainable parameters of a given module. + + Args: + module: A given module + train_bn: If True, leave the BatchNorm layers in training mode + + Returns: + Generator + """ + children = list(module.children()) + if not children: + if not (isinstance(module, BN_TYPES) and train_bn): + for param in module.parameters(): + if param.requires_grad: + yield param + else: + for child in children: + for param in filter_params(module=child, train_bn=train_bn): + yield param + + +def _unfreeze_and_add_param_group(module: torch.nn.Module, + optimizer: Optimizer, + lr: Optional[float] = None, + train_bn: bool = True): + """Unfreezes a module and adds its parameters to an optimizer.""" + _make_trainable(module) + params_lr = optimizer.param_groups[0]['lr'] if lr is None else float(lr) + optimizer.add_param_group( + {'params': filter_params(module=module, train_bn=train_bn), + 'lr': params_lr / 10., + }) + + +# --- Pytorch-lightning module --- + + +class TransferLearningModel(pl.LightningModule): + """Transfer Learning with pre-trained ResNet50. + + Args: + hparams: Model hyperparameters + dl_path: Path where the data will be downloaded + """ + def __init__(self, + hparams: argparse.Namespace, + dl_path: Union[str, Path]) -> None: + super().__init__() + self.hparams = hparams + self.dl_path = dl_path + self.__build_model() + + def __build_model(self): + """Define model layers & loss.""" + + # 1. Load pre-trained network: + model_func = getattr(models, self.hparams.backbone) + backbone = model_func(pretrained=True) + + _layers = list(backbone.children())[:-1] + self.feature_extractor = torch.nn.Sequential(*_layers) + freeze(module=self.feature_extractor, train_bn=self.hparams.train_bn) + + # 2. Classifier: + _fc_layers = [torch.nn.Linear(2048, 256), + torch.nn.Linear(256, 32), + torch.nn.Linear(32, 1)] + self.fc = torch.nn.Sequential(*_fc_layers) + + # 3. Loss: + self.loss_func = F.binary_cross_entropy_with_logits + + def forward(self, x): + """Forward pass. Returns logits.""" + + # 1. Feature extraction: + x = self.feature_extractor(x) + x = x.squeeze(-1).squeeze(-1) + + # 2. Classifier (returns logits): + x = self.fc(x) + + return x + + def loss(self, labels, logits): + return self.loss_func(input=logits, target=labels) + + def train(self, mode=True): + super().train(mode=mode) + + epoch = self.current_epoch + if epoch < self.hparams.milestones[0] and mode: + # feature extractor is frozen (except for BatchNorm layers) + freeze(module=self.feature_extractor, + train_bn=self.hparams.train_bn) + + elif self.hparams.milestones[0] <= epoch < self.hparams.milestones[1] and mode: + # Unfreeze last two layers of the feature extractor + freeze(module=self.feature_extractor, + n=-2, + train_bn=self.hparams.train_bn) + + def on_epoch_start(self): + """Use `on_epoch_start` to unfreeze layers progressively.""" + optimizer = self.trainer.optimizers[0] + if self.current_epoch == self.hparams.milestones[0]: + _unfreeze_and_add_param_group(module=self.feature_extractor[-2:], + optimizer=optimizer, + train_bn=self.hparams.train_bn) + + elif self.current_epoch == self.hparams.milestones[1]: + _unfreeze_and_add_param_group(module=self.feature_extractor[:-2], + optimizer=optimizer, + train_bn=self.hparams.train_bn) + + def training_step(self, batch, batch_idx): + + # 1. Forward pass: + x, y = batch + y_logits = self.forward(x) + y_true = y.view((-1, 1)).type_as(x) + y_bin = torch.ge(y_logits, 0) + + # 2. Compute loss & accuracy: + train_loss = self.loss(y_true, y_logits) + num_correct = torch.eq(y_bin.view(-1), y_true.view(-1)).sum() + + # 3. Outputs: + tqdm_dict = {'train_loss': train_loss} + output = OrderedDict({'loss': train_loss, + 'num_correct': num_correct, + 'log': tqdm_dict, + 'progress_bar': tqdm_dict}) + + return output + + def training_epoch_end(self, outputs): + """Compute and log training loss and accuracy at the epoch level.""" + + train_loss_mean = torch.stack([output['loss'] + for output in outputs]).mean() + train_acc_mean = torch.stack([output['num_correct'] + for output in outputs]).sum().float() + train_acc_mean /= (len(outputs) * self.hparams.batch_size) + return {'log': {'train_loss': train_loss_mean, + 'train_acc': train_acc_mean, + 'step': self.current_epoch}} + + def validation_step(self, batch, batch_idx): + + # 1. Forward pass: + x, y = batch + y_logits = self.forward(x) + y_true = y.view((-1, 1)).type_as(x) + y_bin = torch.ge(y_logits, 0) + + # 2. Compute loss & accuracy: + val_loss = self.loss(y_true, y_logits) + num_correct = torch.eq(y_bin.view(-1), y_true.view(-1)).sum() + + return {'val_loss': val_loss, + 'num_correct': num_correct} + + def validation_epoch_end(self, outputs): + """Compute and log validation loss and accuracy at the epoch level.""" + + val_loss_mean = torch.stack([output['val_loss'] + for output in outputs]).mean() + val_acc_mean = torch.stack([output['num_correct'] + for output in outputs]).sum().float() + val_acc_mean /= (len(outputs) * self.hparams.batch_size) + return {'log': {'val_loss': val_loss_mean, + 'val_acc': val_acc_mean, + 'step': self.current_epoch}} + + def configure_optimizers(self): + optimizer = optim.Adam(filter(lambda p: p.requires_grad, + self.parameters()), + lr=self.hparams.lr) + + scheduler = MultiStepLR(optimizer, + milestones=self.hparams.milestones, + gamma=self.hparams.lr_scheduler_gamma) + + return [optimizer], [scheduler] + + def prepare_data(self): + """Download images and prepare images datasets.""" + + # 1. Download the images + download_and_extract_archive(url=DATA_URL, + download_root=self.dl_path, + remove_finished=True) + + data_path = Path(self.dl_path).joinpath('cats_and_dogs_filtered') + + # 2. Load the data + preprocessing & data augmentation + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = ImageFolder(root=data_path.joinpath('train'), + transform=transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + valid_dataset = ImageFolder(root=data_path.joinpath('validation'), + transform=transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + normalize, + ])) + + self.train_dataset = train_dataset + self.valid_dataset = valid_dataset + + def __dataloader(self, train): + """Train/validation loaders.""" + + _dataset = self.train_dataset if train else self.valid_dataset + loader = DataLoader(dataset=_dataset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + shuffle=True if train else False) + + return loader + + def train_dataloader(self): + log.info('Training data loaded.') + return self.__dataloader(train=True) + + def val_dataloader(self): + log.info('Validation data loaded.') + return self.__dataloader(train=False) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = argparse.ArgumentParser(parents=[parent_parser]) + parser.add_argument('--backbone', + default='resnet50', + type=str, + metavar='BK', + help='Name (as in ``torchvision.models``) of the feature extractor') + parser.add_argument('--epochs', + default=15, + type=int, + metavar='N', + help='total number of epochs', + dest='nb_epochs') + parser.add_argument('--batch-size', + default=8, + type=int, + metavar='B', + help='batch size', + dest='batch_size') + parser.add_argument('--gpus', + type=int, + default=1, + help='number of gpus to use') + parser.add_argument('--lr', + '--learning-rate', + default=1e-2, + type=float, + metavar='LR', + help='initial learning rate', + dest='lr') + parser.add_argument('--lr-scheduler-gamma', + default=1e-1, + type=float, + metavar='LRG', + help='Factor by which the learning rate is reduced at each milestone', + dest='lr_scheduler_gamma') + parser.add_argument('--num-workers', + default=6, + type=int, + metavar='W', + help='number of CPU workers', + dest='num_workers') + parser.add_argument('--train-bn', + default=True, + type=bool, + metavar='TB', + help='Whether the BatchNorm layers should be trainable', + dest='train_bn') + parser.add_argument('--milestones', + default=[5, 10], + type=list, + metavar='M', + help='List of two epochs milestones') + return parser + + +def main(hparams: argparse.Namespace) -> None: + """Train the model. + + Args: + hparams: Model hyper-parameters + + Note: + For the sake of the example, the images dataset will be downloaded + to a temporary directory. + """ + + with TemporaryDirectory(dir=hparams.root_data_path) as tmp_dir: + + model = TransferLearningModel(hparams, dl_path=tmp_dir) + + trainer = pl.Trainer( + weights_summary=None, + show_progress_bar=True, + num_sanity_val_steps=0, + gpus=hparams.gpus, + min_epochs=hparams.nb_epochs, + max_epochs=hparams.nb_epochs) + + trainer.fit(model) + + +def get_args() -> argparse.Namespace: + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser.add_argument('--root-data-path', + metavar='DIR', + type=str, + default=Path.cwd().as_posix(), + help='Root directory where to download the data', + dest='root_data_path') + parser = TransferLearningModel.add_model_specific_args(parent_parser) + return parser.parse_args() + + +if __name__ == '__main__': + + main(get_args())