[BugFix] Resolve bugs in computer_vision_fine_tuning.py example (#5985)

* update the script to use DataModule

* add message for the frozen parameters

* add message about trainable parameters (see the sketch below)

* resolve flake8
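
For readers skimming the diff below: the "trainable parameters" message amounts to counting which parameters still require gradients before the optimizer is built. A rough, self-contained sketch of that idea (the toy model and values are illustrative only, not the example's actual network):

```python
from torch import nn, optim

# Toy stand-in for a frozen backbone plus a trainable classifier head.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
for p in model[0].parameters():  # freeze the "backbone"
    p.requires_grad = False

parameters = list(model.parameters())
trainable_parameters = [p for p in parameters if p.requires_grad]
print(
    f"The model will start training with only {len(trainable_parameters)} "
    f"trainable parameters out of {len(parameters)}."
)
optimizer = optim.Adam(trainable_parameters, lr=1e-3)  # optimize only what is unfrozen
```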
chaton 2021-02-16 21:01:04 +00:00 committed by GitHub
parent 6e79bef996
commit 141316fb29
2 changed files with 105 additions and 89 deletions

.gitignore

@@ -153,4 +153,5 @@ wandb
 cifar-10-batches-py
 *.pt
 # ctags
 tags
+data

computer_vision_fine_tuning.py

@@ -37,12 +37,12 @@ the classifier is trained with lr = 1e-4.
 Note:
     See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
 """
 import argparse
+import os
 from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Union

+import torch
 import torch.nn.functional as F
 from torch import nn, optim
 from torch.optim.lr_scheduler import MultiStepLR
@@ -55,52 +55,114 @@ from torchvision.datasets.utils import download_and_extract_archive
 import pytorch_lightning as pl
 from pl_examples import cli_lightning_logo
 from pytorch_lightning import _logger as log
+from pytorch_lightning import LightningDataModule
 from pytorch_lightning.callbacks.finetuning import BaseFinetuning
+from pytorch_lightning.utilities import rank_zero_info

 DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"

 # --- Finetuning Callback ---


-class MilestonesFinetuningCallback(BaseFinetuning):
+class MilestonesFinetuning(BaseFinetuning):

-    def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
+    def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False):
         self.milestones = milestones
         self.train_bn = train_bn

     def freeze_before_training(self, pl_module: pl.LightningModule):
-        self.freeze(module=pl_module.feature_extractor, train_bn=self.train_bn)
+        self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn)

     def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
         if epoch == self.milestones[0]:
             # unfreeze 5 last layers
             self.unfreeze_and_add_param_group(
-                module=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
+                modules=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
             )

         elif epoch == self.milestones[1]:
             # unfreeze remaining layers
             self.unfreeze_and_add_param_group(
-                module=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
+                modules=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
             )


+class CatDogImageDataModule(LightningDataModule):
+
+    def __init__(
+        self,
+        dl_path: Union[str, Path],
+        num_workers: int = 0,
+        batch_size: int = 8,
+    ):
+        super().__init__()
+
+        self._dl_path = dl_path
+        self._num_workers = num_workers
+        self._batch_size = batch_size
+
+    def prepare_data(self):
+        """Download images and prepare images datasets."""
+        download_and_extract_archive(url=DATA_URL, download_root=self._dl_path, remove_finished=True)
+
+    @property
+    def data_path(self):
+        return Path(self._dl_path).joinpath("cats_and_dogs_filtered")
+
+    @property
+    def normalize_transform(self):
+        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    @property
+    def train_transform(self):
+        return transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(), self.normalize_transform
+        ])
+
+    @property
+    def valid_transform(self):
+        return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), self.normalize_transform])
+
+    def create_dataset(self, root, transform):
+        return ImageFolder(root=root, transform=transform)
+
+    def __dataloader(self, train: bool):
+        """Train/validation loaders."""
+        if train:
+            dataset = self.create_dataset(self.data_path.joinpath("train"), self.train_transform)
+        else:
+            dataset = self.create_dataset(self.data_path.joinpath("validation"), self.valid_transform)
+        return DataLoader(dataset=dataset, batch_size=self._batch_size, num_workers=self._num_workers, shuffle=train)
+
+    def train_dataloader(self):
+        log.info("Training data loaded.")
+        return self.__dataloader(train=True)
+
+    def val_dataloader(self):
+        log.info("Validation data loaded.")
+        return self.__dataloader(train=False)
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = argparse.ArgumentParser(parents=[parent_parser])
+        parser.add_argument(
+            "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
+        )
+        parser.add_argument(
+            "--batch-size", default=8, type=int, metavar="W", help="number of samples in a batch", dest="batch_size"
+        )
+        return parser
+
+
 # --- Pytorch-lightning module ---


 class TransferLearningModel(pl.LightningModule):
-    """Transfer Learning with pre-trained ResNet50.
-
-    >>> with TemporaryDirectory(dir='.') as tmp_dir:
-    ...     TransferLearningModel(tmp_dir)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-    TransferLearningModel(
-      (feature_extractor): Sequential(...)
-      (fc): Sequential(...)
-    )
-    """

     def __init__(
         self,
-        dl_path: Union[str, Path],
         backbone: str = "resnet50",
         train_bn: bool = True,
         milestones: tuple = (5, 10),
@@ -115,7 +177,6 @@ class TransferLearningModel(pl.LightningModule):
             dl_path: Path where the data will be downloaded
         """
         super().__init__()
-        self.dl_path = dl_path
         self.backbone = backbone
         self.train_bn = train_bn
         self.milestones = milestones
@@ -124,7 +185,6 @@ class TransferLearningModel(pl.LightningModule):
         self.lr_scheduler_gamma = lr_scheduler_gamma
         self.num_workers = num_workers
-        self.dl_path = dl_path

         self.__build_model()

         self.train_acc = pl.metrics.Accuracy()
@@ -163,7 +223,7 @@ class TransferLearningModel(pl.LightningModule):
         # 2. Classifier (returns logits):
         x = self.fc(x)

-        return F.sigmoid(x)
+        return torch.sigmoid(x)

     def loss(self, logits, labels):
         return self.loss_func(input=logits, target=labels)
@@ -195,60 +255,16 @@ class TransferLearningModel(pl.LightningModule):
         self.log("val_acc", self.valid_acc(y_logits, y_true.int()), prog_bar=True)

     def configure_optimizers(self):
-        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)
+        parameters = list(self.parameters())
+        trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
+        rank_zero_info(
+            f"The model will start training with only {len(trainable_parameters)} "
+            f"trainable parameters out of {len(parameters)}."
+        )
+        optimizer = optim.Adam(trainable_parameters, lr=self.lr)
         scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
         return [optimizer], [scheduler]

-    def prepare_data(self):
-        """Download images and prepare images datasets."""
-        download_and_extract_archive(url=DATA_URL, download_root=self.dl_path, remove_finished=True)
-
-    def setup(self, stage: str):
-        data_path = Path(self.dl_path).joinpath("cats_and_dogs_filtered")
-
-        # 2. Load the data + preprocessing & data augmentation
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-
-        train_dataset = ImageFolder(
-            root=data_path.joinpath("train"),
-            transform=transforms.Compose([
-                transforms.Resize((224, 224)),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]),
-        )
-
-        valid_dataset = ImageFolder(
-            root=data_path.joinpath("validation"),
-            transform=transforms.Compose([
-                transforms.Resize((224, 224)),
-                transforms.ToTensor(),
-                normalize,
-            ]),
-        )
-
-        self.train_dataset = train_dataset
-        self.valid_dataset = valid_dataset
-
-    def __dataloader(self, train: bool):
-        """Train/validation loaders."""
-        _dataset = self.train_dataset if train else self.valid_dataset
-        loader = DataLoader(dataset=_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=train)
-        return loader
-
-    def train_dataloader(self):
-        log.info("Training data loaded.")
-        return self.__dataloader(train=True)
-
-    def val_dataloader(self):
-        log.info("Validation data loaded.")
-        return self.__dataloader(train=False)
-
     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = argparse.ArgumentParser(parents=[parent_parser])
@@ -263,7 +279,7 @@ class TransferLearningModel(pl.LightningModule):
             "--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
         )
         parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size")
-        parser.add_argument("--gpus", type=int, default=1, help="number of gpus to use")
+        parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use")
         parser.add_argument(
             "--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr"
         )
@@ -275,12 +291,9 @@ class TransferLearningModel(pl.LightningModule):
             help="Factor by which the learning rate is reduced at each milestone",
             dest="lr_scheduler_gamma",
         )
-        parser.add_argument(
-            "--num-workers", default=6, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
-        )
         parser.add_argument(
             "--train-bn",
-            default=True,
+            default=False,
             type=bool,
             metavar="TB",
             help="Whether the BatchNorm layers should be trainable",
@@ -303,21 +316,22 @@ def main(args: argparse.Namespace) -> None:
     to a temporary directory.
     """
-    with TemporaryDirectory(dir=args.root_data_path) as tmp_dir:
-
-        model = TransferLearningModel(dl_path=tmp_dir, **vars(args))
-        finetuning_callback = MilestonesFinetuningCallback(milestones=args.milestones)
-
-        trainer = pl.Trainer(
-            weights_summary=None,
-            progress_bar_refresh_rate=1,
-            num_sanity_val_steps=0,
-            gpus=args.gpus,
-            max_epochs=args.nb_epochs,
-            callbacks=[finetuning_callback]
-        )
-
-        trainer.fit(model)
+    datamodule = CatDogImageDataModule(
+        dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
+    )
+    model = TransferLearningModel(**vars(args))
+    finetuning_callback = MilestonesFinetuning(milestones=args.milestones)
+
+    trainer = pl.Trainer(
+        weights_summary=None,
+        progress_bar_refresh_rate=1,
+        num_sanity_val_steps=0,
+        gpus=args.gpus,
+        max_epochs=args.nb_epochs,
+        callbacks=[finetuning_callback]
+    )
+
+    trainer.fit(model, datamodule=datamodule)


 def get_args() -> argparse.Namespace:
@@ -331,6 +345,7 @@ def get_args() -> argparse.Namespace:
         dest="root_data_path",
     )
     parser = TransferLearningModel.add_model_specific_args(parent_parser)
+    parser = CatDogImageDataModule.add_argparse_args(parser)
     return parser.parse_args()
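
For completeness, after this change the refactored pieces wire together roughly as in the sketch below. The import path and argument values are assumptions for illustration (only the file name is given by the commit title); the class and parameter names come from the diff above.

```python
import pytorch_lightning as pl

# Hypothetical import path -- adjust to wherever the example module lives in your checkout.
from computer_vision_fine_tuning import CatDogImageDataModule, MilestonesFinetuning, TransferLearningModel

# The DataModule now owns download, transforms, and loaders; the model no longer takes dl_path.
datamodule = CatDogImageDataModule(dl_path="./data", batch_size=8, num_workers=0)
model = TransferLearningModel(backbone="resnet50", milestones=(5, 10))
trainer = pl.Trainer(max_epochs=15, callbacks=[MilestonesFinetuning(milestones=(5, 10))])
trainer.fit(model, datamodule=datamodule)
```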