# lightning/examples/pl_loops/kfold.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from os import path
from typing import Any, Dict, List, Optional, Type

import torch
import torchvision.transforms as T
from sklearn.model_selection import KFold
from torch.nn import functional as F
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, Subset
from torchmetrics.classification.accuracy import Accuracy

from pytorch_lightning import LightningDataModule, seed_everything, Trainer
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.loops.loop import Loop
from pytorch_lightning.trainer.states import TrainerFn

DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

#############################################################################################
# KFold Loop / Cross Validation Example #
# This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 #
# Learn more about the loop structure from the documentation: #
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html #
#############################################################################################
#############################################################################################
# Step 1 / 5: Define KFold DataModule API #
# Our KFold DataModule requires implementing the `setup_folds` and `setup_fold_index` #
# methods. #
#############################################################################################
class BaseKFoldDataModule(LightningDataModule, ABC):
    @abstractmethod
    def setup_folds(self, num_folds: int) -> None:
        pass

    @abstractmethod
    def setup_fold_index(self, fold_index: int) -> None:
        pass
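
# Any LightningDataModule can opt into cross validation by implementing these two hooks:
# `setup_folds` computes and stores the index splits once, while `setup_fold_index`
# selects the train/val subsets for the fold about to run (see `MNISTKFoldDataModule` below).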


#############################################################################################
# Step 2 / 5: Implement the KFoldDataModule #
# The `KFoldDataModule` will take a train and test dataset. #
# On `setup_folds`, folds will be created depending on the provided argument `num_folds`. #
# On `setup_fold_index`, the provided train dataset will be split according to #
# the current fold split. #
#############################################################################################
@dataclass
class MNISTKFoldDataModule(BaseKFoldDataModule):
    train_dataset: Optional[Dataset] = None
    test_dataset: Optional[Dataset] = None
    train_fold: Optional[Dataset] = None
    val_fold: Optional[Dataset] = None

    def prepare_data(self) -> None:
        # download the data.
        MNIST(DATASETS_PATH, download=True, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))

    def setup(self, stage: str) -> None:
        # load the data
        dataset = MNIST(DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))
        self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000])

    def setup_folds(self, num_folds: int) -> None:
        self.num_folds = num_folds
        self.splits = [split for split in KFold(num_folds).split(range(len(self.train_dataset)))]

    def setup_fold_index(self, fold_index: int) -> None:
        train_indices, val_indices = self.splits[fold_index]
        self.train_fold = Subset(self.train_dataset, train_indices)
        self.val_fold = Subset(self.train_dataset, val_indices)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_fold)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_fold)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset)

    def __post_init__(self):
        # `@dataclass` generates `__init__`, so `LightningDataModule.__init__` must be
        # invoked explicitly once the dataclass fields have been initialized.
        super().__init__()
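
# A minimal sketch (not part of the original example) of how these hooks are meant to be
# exercised by an outer loop, assuming a 3-fold split:
#
#     dm = MNISTKFoldDataModule()
#     dm.prepare_data()
#     dm.setup("fit")
#     dm.setup_folds(3)
#     dm.setup_fold_index(0)  # validate on fold 0, train on the remaining folds
#     train_loader = dm.train_dataloader()  # DataLoader over the current train subset
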
#############################################################################################
# Step 3 / 5: Implement the EnsembleVotingModel module #
# The `EnsembleVotingModel` will take our custom LightningModule and #
# several checkpoint_paths. #
# #
#############################################################################################
class EnsembleVotingModel(LightningModule):
    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]) -> None:
        super().__init__()
        # Create `num_folds` models with their associated fold weights
        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
        self.test_acc = Accuracy()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        # Compute the averaged predictions over the `num_folds` models.
        logits = torch.stack([m(batch[0]) for m in self.models]).mean(0)
        loss = F.nll_loss(logits, batch[1])
        self.test_acc(logits, batch[1])
        self.log("test_acc", self.test_acc)
        self.log("test_loss", loss)
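
# Note: stacking and averaging the per-model outputs before `nll_loss` assumes each fold
# model returns log-probabilities, consistent with the use of `F.nll_loss` throughout this
# example; for raw logits one would instead average probabilities after a softmax.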


#############################################################################################
# Step 4 / 5: Implement the KFoldLoop #
# From Lightning v1.5, it is possible to implement your own loop. There are several steps #
# to do so, which are described in detail within the documentation #
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html. #
# Here, we will implement an outer fit_loop. It means we will subclass the base Loop #
# and wrap the current trainer `fit_loop`. #
#############################################################################################


#############################################################################################
# Here is the `Pseudo Code` for the base Loop: #
# class Loop: #
# #
#     def run(self, ...): #
#         self.reset(...) #
#         self.on_run_start(...) #
# #
#         while not self.done: #
#             self.on_advance_start(...) #
#             self.advance(...) #
#             self.on_advance_end(...) #
# #
#         return self.on_run_end(...) #
#############################################################################################
class KFoldLoop(Loop):
    def __init__(self, num_folds: int, export_path: str) -> None:
        super().__init__()
        self.num_folds = num_folds
        self.current_fold: int = 0
        self.export_path = export_path

    @property
    def done(self) -> bool:
        return self.current_fold >= self.num_folds

    def connect(self, fit_loop: FitLoop) -> None:
        self.fit_loop = fit_loop

    def reset(self) -> None:
        """Nothing to reset in this loop."""

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
        model."""
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_folds(self.num_folds)
        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
        print(f"STARTING FOLD {self.current_fold}")
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_fold_index(self.current_fold)

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Used to run fitting and testing on the current fold."""
        self._reset_fitting()  # reset the Trainer state for fitting.
        self.fit_loop.run()

        self._reset_testing()  # reset the Trainer state for testing.
        # the test loop normally expects the model to be the pure LightningModule, but since we are running the
        # test loop during fitting, we need to temporarily unpack the wrapped module
        wrapped_model = self.trainer.strategy.model
        self.trainer.strategy.model = self.trainer.strategy.lightning_module
        self.trainer.test_loop.run()
        self.trainer.strategy.model = wrapped_model
        self.current_fold += 1  # increment fold tracking number.

    def on_advance_end(self) -> None:
        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
        # restore the original weights + optimizers and schedulers.
        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
        self.trainer.strategy.setup_optimizers(self.trainer)
        self.replace(fit_loop=FitLoop)

    def on_run_end(self) -> None:
        """Used to compute the performance of the ensemble model on the test set."""
        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
        voting_model.trainer = self.trainer
        # This requires connecting the new model and moving it to the right device.
        self.trainer.strategy.connect(voting_model)
        self.trainer.strategy.model_to_device()
        self.trainer.test_loop.run()

    def on_save_checkpoint(self) -> Dict[str, int]:
        return {"current_fold": self.current_fold}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        self.current_fold = state_dict["current_fold"]
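
    # `on_save_checkpoint` / `on_load_checkpoint` store and restore the fold index with the
    # Trainer checkpoint, so an interrupted cross-validation run can resume at the right fold.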

    def _reset_fitting(self) -> None:
        self.trainer.reset_train_dataloader()
        self.trainer.reset_val_dataloader()
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True

    def _reset_testing(self) -> None:
        self.trainer.reset_test_dataloader()
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.testing = True

    def __getattr__(self, key) -> Any:
        # needed so that attribute lookups fall through to the wrapped `fit_loop`.
        if key not in self.__dict__:
            return getattr(self.fit_loop, key)
        return self.__dict__[key]

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)


class LitImageClassifier(LightningModule):
    def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32):
        super().__init__()
        self.save_hyperparameters(ignore="model")
        self.model = model or Net()
        self.test_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        self.val_acc(logits, y)
        self.log("val_acc", self.val_acc)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        self.test_acc(logits, y)
        self.log("test_acc", self.test_acc)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr)
        return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)]
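
# Note how `configure_optimizers` is the hook that `KFoldLoop.on_advance_end` relies on:
# `Strategy.setup_optimizers` re-invokes it after each fold, so every fold starts from a
# fresh optimizer and scheduler alongside the restored initial weights.
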
#############################################################################################
# Step 5 / 5: Connect the KFoldLoop to the Trainer #
# After creating the `KFoldDataModule` and our model, the `KFoldLoop` is connected to #
# the Trainer. #
# Finally, use `trainer.fit` to start the cross validation training. #
#############################################################################################
if __name__ == "__main__":
    seed_everything(42)

    model = LitImageClassifier()
    datamodule = MNISTKFoldDataModule()
    trainer = Trainer(
        max_epochs=10,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        num_sanity_val_steps=0,
        devices=2,
        accelerator="cpu",
        strategy="ddp",
    )
    internal_fit_loop = trainer.fit_loop
    trainer.fit_loop = KFoldLoop(5, export_path="./")
    trainer.fit_loop.connect(internal_fit_loop)
    trainer.fit(model, datamodule)
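
# With the settings above, a run is expected to train five folds in sequence, export
# `model.1.pt` ... `model.5.pt` next to the script (`export_path="./"`), and finish by
# evaluating the `EnsembleVotingModel` once on the shared test split (exact metric values
# will vary with the environment).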