# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computer vision example on Transfer Learning.

This computer vision example illustrates how one could fine-tune a pre-trained network (by default, a ResNet50 is
used) using pytorch-lightning. For the sake of this example, the proposed network (denoted by
`TransferLearningModel`, see below) is trained for 15 epochs on the 'cats and dogs dataset' (~60MB, see `DATA_URL`
below).

The training consists of three stages.

From epoch 0 to 4, the feature extractor (the pre-trained network) is frozen except maybe for the BatchNorm layers
(depending on whether `train_bn = True`). The BatchNorm layers (if `train_bn = True`) and the parameters of the
classifier are trained as a single parameter group with lr = 1e-2.

From epoch 5 to 9, the last five layer groups of the pre-trained network are unfrozen and added to the optimizer as
a new parameter group with lr = 1e-4 (while lr = 1e-3 for the first parameter group in the optimizer).

Eventually, from epoch 10, all the remaining layer groups of the pre-trained network are unfrozen and added to the
optimizer as a third parameter group. From epoch 10 onwards, the parameters of the pre-trained network are trained
with lr = 1e-5 while those of the classifier are trained with lr = 1e-4.

Note:
    See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

To run:
    python computer_vision_fine_tuning.py fit
"""
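# Beyond the plain ``fit`` invocation above, individual arguments can be overridden on the
# command line. The flags below are a sketch: they assume the ``MyLightningCLI`` wiring at
# the bottom of this file (which registers the callback under the ``finetuning`` namespace)
# and standard LightningCLI/jsonargparse syntax:
#
#   python computer_vision_fine_tuning.py fit --trainer.max_epochs 20
#   python computer_vision_fine_tuning.py fit --finetuning.milestones "[3, 6]" --finetuning.train_bn true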
import logging
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning import LightningDataModule
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.rank_zero import rank_zero_info

log = logging.getLogger(__name__)

DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"


# --- Finetuning Callback ---


class MilestonesFinetuning(BaseFinetuning):
    def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False):
        super().__init__()
        self.milestones = milestones
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module: pl.LightningModule):
        # freeze the whole backbone (optionally keeping BatchNorm layers trainable)
        self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn)

    def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
        if epoch == self.milestones[0]:
            # unfreeze the last 5 layer groups of the backbone
            self.unfreeze_and_add_param_group(
                modules=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
            )

        elif epoch == self.milestones[1]:
            # unfreeze the remaining layers
            self.unfreeze_and_add_param_group(
                modules=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
            )
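# A note on the callback contract above, per the ``BaseFinetuning`` documentation:
# ``freeze_before_training`` runs once before fitting starts, and ``finetune_function``
# is called at the start of every training epoch, so each milestone branch above fires
# exactly once. Newly unfrozen parameters are appended to the existing optimizer as an
# extra parameter group, which produces the staged schedule described in the module
# docstring.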
class CatDogImageDataModule(LightningDataModule):
    def __init__(self, dl_path: Union[str, Path] = "data", num_workers: int = 0, batch_size: int = 8):
        """CatDogImageDataModule.

        Args:
            dl_path: root directory where the dataset is downloaded
            num_workers: number of CPU workers used by the dataloaders
            batch_size: number of samples in a batch
        """
        super().__init__()

        self._dl_path = dl_path
        self._num_workers = num_workers
        self._batch_size = batch_size

    def prepare_data(self):
        """Download and extract the image dataset."""
        download_and_extract_archive(url=DATA_URL, download_root=self._dl_path, remove_finished=True)

    @property
    def data_path(self):
        return Path(self._dl_path).joinpath("cats_and_dogs_filtered")

    @property
    def normalize_transform(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    @property
    def train_transform(self):
        return transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                self.normalize_transform,
            ]
        )

    @property
    def valid_transform(self):
        return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), self.normalize_transform])

    def create_dataset(self, root, transform):
        return ImageFolder(root=root, transform=transform)

    def __dataloader(self, train: bool):
        """Train/validation loaders."""
        if train:
            dataset = self.create_dataset(self.data_path.joinpath("train"), self.train_transform)
        else:
            dataset = self.create_dataset(self.data_path.joinpath("validation"), self.valid_transform)
        return DataLoader(dataset=dataset, batch_size=self._batch_size, num_workers=self._num_workers, shuffle=train)

    def train_dataloader(self):
        log.info("Training data loaded.")
        return self.__dataloader(train=True)

    def val_dataloader(self):
        log.info("Validation data loaded.")
        return self.__dataloader(train=False)
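# Standalone usage sketch for the datamodule (hypothetical; not part of the CLI flow below
# and assuming `DATA_URL` is reachable). ``ImageFolder`` assigns class indices in
# alphabetical order of the folder names, so here cats = 0 and dogs = 1:
#
#     dm = CatDogImageDataModule(dl_path="data", batch_size=8)
#     dm.prepare_data()
#     images, labels = next(iter(dm.train_dataloader()))
#     # images is a [8, 3, 224, 224] normalized float tensor, labels a [8] tensor of 0/1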
# --- Pytorch-lightning module ---


class TransferLearningModel(pl.LightningModule):
    def __init__(
        self,
        backbone: str = "resnet50",
        train_bn: bool = False,
        milestones: tuple = (2, 4),
        batch_size: int = 32,
        lr: float = 1e-3,
        lr_scheduler_gamma: float = 1e-1,
        num_workers: int = 6,
        **kwargs,
    ) -> None:
        """TransferLearningModel.

        Args:
            backbone: Name (as in ``torchvision.models``) of the feature extractor
            train_bn: Whether the BatchNorm layers should be trainable
            milestones: List of two epoch milestones
            batch_size: Number of samples in a batch
            lr: Initial learning rate
            lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
            num_workers: Number of CPU workers
        """
        super().__init__()
        self.backbone = backbone
        self.train_bn = train_bn
        self.milestones = milestones
        self.batch_size = batch_size
        self.lr = lr
        self.lr_scheduler_gamma = lr_scheduler_gamma
        self.num_workers = num_workers

        self.__build_model()

        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.save_hyperparameters()

    def __build_model(self):
        """Define model layers & loss."""
        # 1. Load pre-trained network and drop its final classification layer:
        model_func = getattr(models, self.backbone)
        backbone = model_func(pretrained=True)

        _layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*_layers)

        # 2. Classifier (2048 matches the pooled feature size of ResNet50; other backbones may differ):
        _fc_layers = [nn.Linear(2048, 256), nn.ReLU(), nn.Linear(256, 32), nn.Linear(32, 1)]
        self.fc = nn.Sequential(*_fc_layers)

        # 3. Loss:
        self.loss_func = F.binary_cross_entropy_with_logits

    def forward(self, x):
        """Forward pass.

        Returns logits.
        """
        # 1. Feature extraction:
        x = self.feature_extractor(x)
        # drop the trailing 1x1 spatial dimensions left by the average pool
        x = x.squeeze(-1).squeeze(-1)

        # 2. Classifier (returns logits):
        x = self.fc(x)

        return x

    def loss(self, logits, labels):
        return self.loss_func(input=logits, target=labels)

    def training_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_scores = torch.sigmoid(y_logits)
        y_true = y.view((-1, 1)).type_as(x)

        # 2. Compute loss:
        train_loss = self.loss(y_logits, y_true)

        # 3. Compute accuracy:
        self.log("train_acc", self.train_acc(y_scores, y_true.int()), prog_bar=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_scores = torch.sigmoid(y_logits)
        y_true = y.view((-1, 1)).type_as(x)

        # 2. Compute loss:
        self.log("val_loss", self.loss(y_logits, y_true), prog_bar=True)

        # 3. Compute accuracy:
        self.log("val_acc", self.valid_acc(y_scores, y_true.int()), prog_bar=True)

    def configure_optimizers(self):
        parameters = list(self.parameters())
        trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
        rank_zero_info(
            f"The model will start training with only {len(trainable_parameters)} "
            f"trainable parameters out of {len(parameters)}."
        )
        optimizer = optim.Adam(trainable_parameters, lr=self.lr)
        scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
        return [optimizer], [scheduler]


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_lightning_class_args(MilestonesFinetuning, "finetuning")
        parser.link_arguments("data.batch_size", "model.batch_size")
        parser.link_arguments("finetuning.milestones", "model.milestones")
        parser.link_arguments("finetuning.train_bn", "model.train_bn")
        parser.set_defaults(
            {
                "trainer.max_epochs": 15,
                "trainer.enable_model_summary": False,
                "trainer.num_sanity_val_steps": 0,
            }
        )


def cli_main():
    MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)


if __name__ == "__main__":
    cli_lightning_logo()
    cli_main()
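# Manual (non-CLI) usage sketch, assuming the defaults above. This only illustrates,
# roughly, what ``cli_main`` wires together:
#
#     model = TransferLearningModel()
#     datamodule = CatDogImageDataModule()
#     trainer = pl.Trainer(max_epochs=15, callbacks=[MilestonesFinetuning(milestones=(5, 10))])
#     trainer.fit(model, datamodule=datamodule)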