diff --git a/.gitignore b/.gitignore
index 9fcf0e1e29..390551b8f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -153,4 +153,5 @@ wandb
 cifar-10-batches-py
 *.pt
 # ctags
-tags
\ No newline at end of file
+tags
+data
diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py
index ea4e9e0275..65bf1bde14 100644
--- a/pl_examples/domain_templates/computer_vision_fine_tuning.py
+++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py
@@ -37,12 +37,12 @@ the classifier is trained with lr = 1e-4.
 Note:
 See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
 """
-
 import argparse
+import os
 from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Union
 
+import torch
 import torch.nn.functional as F
 from torch import nn, optim
 from torch.optim.lr_scheduler import MultiStepLR
@@ -55,52 +55,114 @@ from torchvision.datasets.utils import download_and_extract_archive
 import pytorch_lightning as pl
 from pl_examples import cli_lightning_logo
 from pytorch_lightning import _logger as log
+from pytorch_lightning import LightningDataModule
 from pytorch_lightning.callbacks.finetuning import BaseFinetuning
+from pytorch_lightning.utilities import rank_zero_info
 
 DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
 
 
 # --- Finetuning Callback ---
 
 
-class MilestonesFinetuningCallback(BaseFinetuning):
+class MilestonesFinetuning(BaseFinetuning):
 
-    def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
+    def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False):
         self.milestones = milestones
         self.train_bn = train_bn
 
     def freeze_before_training(self, pl_module: pl.LightningModule):
-        self.freeze(module=pl_module.feature_extractor, train_bn=self.train_bn)
+        self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn)
 
     def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
         if epoch == self.milestones[0]:
             # unfreeze 5 last layers
             self.unfreeze_and_add_param_group(
-                module=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
+                modules=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
             )
 
         elif epoch == self.milestones[1]:
             # unfreeze remaing layers
             self.unfreeze_and_add_param_group(
-                module=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
+                modules=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
             )
 
 
+class CatDogImageDataModule(LightningDataModule):
+
+    def __init__(
+        self,
+        dl_path: Union[str, Path],
+        num_workers: int = 0,
+        batch_size: int = 8,
+    ):
+        super().__init__()
+
+        self._dl_path = dl_path
+        self._num_workers = num_workers
+        self._batch_size = batch_size
+
+    def prepare_data(self):
+        """Download images and prepare images datasets."""
+        download_and_extract_archive(url=DATA_URL, download_root=self._dl_path, remove_finished=True)
+
+    @property
+    def data_path(self):
+        return Path(self._dl_path).joinpath("cats_and_dogs_filtered")
+
+    @property
+    def normalize_transform(self):
+        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    @property
+    def train_transform(self):
+        return transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(), self.normalize_transform
+        ])
+
+    @property
+    def valid_transform(self):
+        return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), self.normalize_transform])
+
+    def create_dataset(self, root, transform):
+        return ImageFolder(root=root, transform=transform)
+
+    def __dataloader(self, train: bool):
+        """Train/validation loaders."""
+        if train:
+            dataset = self.create_dataset(self.data_path.joinpath("train"), self.train_transform)
+        else:
+            dataset = self.create_dataset(self.data_path.joinpath("validation"), self.valid_transform)
+        return DataLoader(dataset=dataset, batch_size=self._batch_size, num_workers=self._num_workers, shuffle=train)
+
+    def train_dataloader(self):
+        log.info("Training data loaded.")
+        return self.__dataloader(train=True)
+
+    def val_dataloader(self):
+        log.info("Validation data loaded.")
+        return self.__dataloader(train=False)
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = argparse.ArgumentParser(parents=[parent_parser])
+        parser.add_argument(
+            "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
+        )
+        parser.add_argument(
+            "--batch-size", default=8, type=int, metavar="W", help="number of samples in a batch", dest="batch_size"
+        )
+        return parser
+
+
 # --- Pytorch-lightning module ---
 
 
 class TransferLearningModel(pl.LightningModule):
-    """Transfer Learning with pre-trained ResNet50.
-    >>> with TemporaryDirectory(dir='.') as tmp_dir:
-    ...     TransferLearningModel(tmp_dir)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-    TransferLearningModel(
-      (feature_extractor): Sequential(...)
-      (fc): Sequential(...)
-    )
-    """
 
     def __init__(
         self,
-        dl_path: Union[str, Path],
         backbone: str = "resnet50",
         train_bn: bool = True,
         milestones: tuple = (5, 10),
@@ -115,7 +177,6 @@ class TransferLearningModel(pl.LightningModule):
             dl_path: Path where the data will be downloaded
         """
         super().__init__()
-        self.dl_path = dl_path
         self.backbone = backbone
         self.train_bn = train_bn
         self.milestones = milestones
@@ -124,7 +185,6 @@ class TransferLearningModel(pl.LightningModule):
         self.lr_scheduler_gamma = lr_scheduler_gamma
         self.num_workers = num_workers
 
-        self.dl_path = dl_path
         self.__build_model()
 
         self.train_acc = pl.metrics.Accuracy()
@@ -163,7 +223,7 @@ class TransferLearningModel(pl.LightningModule):
         # 2. Classifier (returns logits):
         x = self.fc(x)
 
-        return F.sigmoid(x)
+        return torch.sigmoid(x)
 
     def loss(self, logits, labels):
         return self.loss_func(input=logits, target=labels)
@@ -195,60 +255,16 @@ class TransferLearningModel(pl.LightningModule):
         self.log("val_acc", self.valid_acc(y_logits, y_true.int()), prog_bar=True)
 
     def configure_optimizers(self):
-        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)
-
+        parameters = list(self.parameters())
+        trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
+        rank_zero_info(
+            f"The model will start training with only {len(trainable_parameters)} "
+            f"trainable parameters out of {len(parameters)}."
+        )
+        optimizer = optim.Adam(trainable_parameters, lr=self.lr)
         scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
-
         return [optimizer], [scheduler]
 
-    def prepare_data(self):
-        """Download images and prepare images datasets."""
-        download_and_extract_archive(url=DATA_URL, download_root=self.dl_path, remove_finished=True)
-
-    def setup(self, stage: str):
-        data_path = Path(self.dl_path).joinpath("cats_and_dogs_filtered")
-
-        # 2. Load the data + preprocessing & data augmentation
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-
-        train_dataset = ImageFolder(
-            root=data_path.joinpath("train"),
-            transform=transforms.Compose([
-                transforms.Resize((224, 224)),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]),
-        )
-
-        valid_dataset = ImageFolder(
-            root=data_path.joinpath("validation"),
-            transform=transforms.Compose([
-                transforms.Resize((224, 224)),
-                transforms.ToTensor(),
-                normalize,
-            ]),
-        )
-
-        self.train_dataset = train_dataset
-        self.valid_dataset = valid_dataset
-
-    def __dataloader(self, train: bool):
-        """Train/validation loaders."""
-
-        _dataset = self.train_dataset if train else self.valid_dataset
-        loader = DataLoader(dataset=_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=train)
-
-        return loader
-
-    def train_dataloader(self):
-        log.info("Training data loaded.")
-        return self.__dataloader(train=True)
-
-    def val_dataloader(self):
-        log.info("Validation data loaded.")
-        return self.__dataloader(train=False)
-
     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = argparse.ArgumentParser(parents=[parent_parser])
@@ -263,7 +279,7 @@ class TransferLearningModel(pl.LightningModule):
             "--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
        )
         parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size")
-        parser.add_argument("--gpus", type=int, default=1, help="number of gpus to use")
+        parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use")
         parser.add_argument(
             "--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr"
         )
@@ -275,12 +291,9 @@ class TransferLearningModel(pl.LightningModule):
             help="Factor by which the learning rate is reduced at each milestone",
             dest="lr_scheduler_gamma",
         )
-        parser.add_argument(
-            "--num-workers", default=6, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
-        )
         parser.add_argument(
             "--train-bn",
-            default=True,
+            default=False,
             type=bool,
             metavar="TB",
             help="Whether the BatchNorm layers should be trainable",
@@ -303,21 +316,22 @@ def main(args: argparse.Namespace) -> None:
     to a temporary directory.
     """
 
-    with TemporaryDirectory(dir=args.root_data_path) as tmp_dir:
+    datamodule = CatDogImageDataModule(
+        dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
+    )
+    model = TransferLearningModel(**vars(args))
+    finetuning_callback = MilestonesFinetuning(milestones=args.milestones)
 
-        model = TransferLearningModel(dl_path=tmp_dir, **vars(args))
-        finetuning_callback = MilestonesFinetuningCallback(milestones=args.milestones)
+    trainer = pl.Trainer(
+        weights_summary=None,
+        progress_bar_refresh_rate=1,
+        num_sanity_val_steps=0,
+        gpus=args.gpus,
+        max_epochs=args.nb_epochs,
+        callbacks=[finetuning_callback]
+    )
 
-        trainer = pl.Trainer(
-            weights_summary=None,
-            progress_bar_refresh_rate=1,
-            num_sanity_val_steps=0,
-            gpus=args.gpus,
-            max_epochs=args.nb_epochs,
-            callbacks=[finetuning_callback]
-        )
-
-        trainer.fit(model)
+    trainer.fit(model, datamodule=datamodule)
 
 
 def get_args() -> argparse.Namespace:
@@ -331,6 +345,7 @@ def get_args() -> argparse.Namespace:
         dest="root_data_path",
     )
     parser = TransferLearningModel.add_model_specific_args(parent_parser)
+    parser = CatDogImageDataModule.add_argparse_args(parser)
     return parser.parse_args()
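
Taken together, the diff moves all data handling into the new `CatDogImageDataModule` and keeps the staged unfreezing in the renamed `MilestonesFinetuning` callback, with `main()` wiring the two into `trainer.fit(model, datamodule=datamodule)`. As a quick orientation, here is a minimal usage sketch equivalent to `main()`; the import path comes from the file header above, while the concrete values (`dl_path="data"`, `num_workers=2`, `gpus=0`, `max_epochs=15`) are illustrative assumptions, not part of the diff:

    import pytorch_lightning as pl

    from pl_examples.domain_templates.computer_vision_fine_tuning import (
        CatDogImageDataModule,
        MilestonesFinetuning,
        TransferLearningModel,
    )

    # prepare_data() downloads and extracts the cats-and-dogs archive under
    # dl_path; the train/validation ImageFolder loaders are built on demand.
    datamodule = CatDogImageDataModule(dl_path="data", batch_size=8, num_workers=2)

    # The ResNet50 feature extractor starts frozen; the callback unfreezes its
    # last 5 layers at epoch 5 and the remaining layers at epoch 10.
    model = TransferLearningModel(backbone="resnet50", milestones=(5, 10))
    trainer = pl.Trainer(
        gpus=0,
        max_epochs=15,
        callbacks=[MilestonesFinetuning(milestones=(5, 10))],
    )
    trainer.fit(model, datamodule=datamodule)

When run as a script, `get_args()` produces the same wiring, now also exposing the datamodule's constructor arguments on the command line via `CatDogImageDataModule.add_argparse_args(parser)`.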