Add KFold Loop example (#9965)
This commit is contained in:
parent
a99b7440b5
commit
86df7dcee7
@ -186,14 +186,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
- Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))
- LightningLite:
  * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))

@ -395,7 +395,17 @@ To run the following demo, install Flash and `BaaL <https://github.com/ElementAI
    # 5. Save the model!
    trainer.save_checkpoint("image_classification_model.pt")

-Here is the `runnable example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py#L31>`_.
+Here is the `Active Learning Loop example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py#L31>`_.

`KFold / Cross Validation <https://en.wikipedia.org/wiki/Cross-validation_(statistics)>`__ is a machine learning practice in which the training dataset is partitioned into `num_folds` complementary subsets.
One cross-validation round performs a fit where one fold is left out for validation and the remaining folds are used for training.
To reduce variability, once all rounds are performed using the different folds, the trained models are ensembled and their predictions are
averaged when estimating the model's predictive performance on the test dataset.
KFold can be elegantly implemented with `Lightning Loop Customization` as follows:
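
A minimal sketch of the wiring (`KFoldLoop`, `MNISTKFoldDataModule` and `LitClassifier` are assumed to come from the full example linked below)::

    from pytorch_lightning import Trainer

    model = LitClassifier()
    trainer = Trainer(num_sanity_val_steps=0)
    # swap the default fit loop for the KFold loop, which wraps it
    trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./")
    trainer.fit(model, datamodule=MNISTKFoldDataModule())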

Here is the `KFold Loop example <https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/loops/kfold.py>`_.


Advanced Topics and Examples
----------------------------

@ -26,18 +26,21 @@ from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib

-_TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
-if _TORCHVISION_MNIST_AVAILABLE:
-    try:
-        from torchvision.datasets import MNIST
-
-        MNIST(_DATASETS_PATH, download=True)
-    except HTTPError as e:
-        print(f"Error {e} downloading `torchvision.datasets.MNIST`")
-        _TORCHVISION_MNIST_AVAILABLE = False
-if not _TORCHVISION_MNIST_AVAILABLE:
-    print("`torchvision.datasets.MNIST` not available. Using our hosted version")
-    from tests.helpers.datasets import MNIST
+def MNIST(*args, **kwargs):
+    torchvision_mnist_available = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
+    if torchvision_mnist_available:
+        try:
+            from torchvision.datasets import MNIST
+
+            MNIST(_DATASETS_PATH, download=True)
+        except HTTPError as e:
+            print(f"Error {e} downloading `torchvision.datasets.MNIST`")
+            torchvision_mnist_available = False
+    if not torchvision_mnist_available:
+        print("`torchvision.datasets.MNIST` not available. Using our hosted version")
+        from tests.helpers.datasets import MNIST
+    return MNIST(*args, **kwargs)


 class MNISTDataModule(LightningDataModule):
@ -0,0 +1,256 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type

import torch
import torchvision.transforms as T
from sklearn.model_selection import KFold
from torch.nn import functional as F
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, Subset

from pl_examples import _DATASETS_PATH
from pl_examples.basic_examples.mnist_datamodule import MNIST
from pl_examples.basic_examples.simple_image_classifier import LitClassifier
from pytorch_lightning import LightningDataModule, seed_everything, Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.states import TrainerFn

#############################################################################################
#                          KFold Loop / Cross Validation Example                            #
# This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 #
# Learn more about the loop structure from the documentation:                               #
#     https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html              #
#############################################################################################


seed_everything(42)


#############################################################################################
#                         Step 1 / 5: Define KFold DataModule API                           #
# Our KFold DataModule requires implementing the `setup_folds` and `setup_fold_index`       #
# methods.                                                                                  #
#############################################################################################


class BaseKFoldDataModule(LightningDataModule, ABC):
    @abstractmethod
    def setup_folds(self, num_folds: int) -> None:
        pass

    @abstractmethod
    def setup_fold_index(self, fold_index: int) -> None:
        pass


#############################################################################################
#                        Step 2 / 5: Implement the KFoldDataModule                          #
# The `KFoldDataModule` will take a train and test dataset.                                 #
# In `setup_folds`, folds are created from the provided argument `num_folds`.               #
# In `setup_fold_index`, the provided train dataset is split according to the current       #
# fold split.                                                                               #
#############################################################################################


@dataclass
class MNISTKFoldDataModule(BaseKFoldDataModule):

    train_dataset: Optional[Dataset] = None
    test_dataset: Optional[Dataset] = None
    train_fold: Optional[Dataset] = None
    val_fold: Optional[Dataset] = None

    def prepare_data(self) -> None:
        # download the data.
        MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))

    def setup(self, stage: Optional[str] = None) -> None:
        # load the data
        dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))
        self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000])

    def setup_folds(self, num_folds: int) -> None:
        self.num_folds = num_folds
        self.splits = [split for split in KFold(num_folds).split(range(len(self.train_dataset)))]

    def setup_fold_index(self, fold_index: int) -> None:
        train_indices, val_indices = self.splits[fold_index]
        self.train_fold = Subset(self.train_dataset, train_indices)
        self.val_fold = Subset(self.train_dataset, val_indices)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_fold)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_fold)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset)
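

# Illustration (not part of the original example): how the `KFoldLoop` below drives the
# datamodule API defined above. Folds are computed once; each fold index then materializes
# the `train_fold` / `val_fold` subsets consumed by the dataloaders:
#
#   datamodule = MNISTKFoldDataModule()
#   datamodule.prepare_data()
#   datamodule.setup("fit")
#   datamodule.setup_folds(num_folds=5)  # compute the 5 (train, val) index splits
#   datamodule.setup_fold_index(0)       # select fold 0 as train_fold / val_fold
#   train_loader = datamodule.train_dataloader()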


#############################################################################################
#                  Step 3 / 5: Implement the EnsembleVotingModel module                     #
# The `EnsembleVotingModel` will take our custom LightningModule and                        #
# several checkpoint_paths.                                                                 #
#                                                                                           #
#############################################################################################


class EnsembleVotingModel(LightningModule):
    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]):
        super().__init__()
        # Create `num_folds` models with their associated fold weights
        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        # Compute the averaged predictions over the `num_folds` models.
        logits = torch.stack([m(batch[0]) for m in self.models]).mean(0)
        loss = F.cross_entropy(logits, batch[1])
        self.log("test_loss", loss)
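

# Illustration (not part of the original example) of the soft-voting shapes in `test_step`:
# with `k` models and a batch of `b` MNIST images, torch.stack([m(batch[0]) for m in
# self.models]) has shape (k, b, 10), and .mean(0) yields the (b, 10) ensemble logits
# passed to `F.cross_entropy`.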


#############################################################################################
#                           Step 4 / 5: Implement the KFoldLoop                             #
# From Lightning v1.5, it is possible to implement your own loop. There are several steps   #
# to do so which are described in detail within the documentation                           #
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html.                 #
# Here, we will implement an outer fit_loop. It means we will subclass the base Loop        #
# and wrap the current trainer `fit_loop`.                                                  #
#############################################################################################


#############################################################################################
#                      Here is the `Pseudo Code` for the base Loop.                         #
# class Loop:                                                                               #
#                                                                                           #
#     def run(self, ...):                                                                   #
#         self.reset(...)                                                                   #
#         self.on_run_start(...)                                                            #
#                                                                                           #
#         while not self.done:                                                              #
#             self.on_advance_start(...)                                                    #
#             self.advance(...)                                                             #
#             self.on_advance_end(...)                                                      #
#                                                                                           #
#         return self.on_run_end(...)                                                       #
#############################################################################################
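

# Applied to the `KFoldLoop` below, a single `run()` therefore performs, once:
#   reset / on_run_start -> datamodule.setup_folds(num_folds), snapshot the model weights
# then, for each fold until `done` (current_fold >= num_folds):
#   on_advance_start -> datamodule.setup_fold_index(current_fold)
#   advance          -> fit_loop.run() followed by test_loop.run()
#   on_advance_end   -> save the fold checkpoint, restore the original weights/optimizers
# and finally:
#   on_run_end       -> test the `EnsembleVotingModel` built from all fold checkpoints.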


class KFoldLoop(Loop):
    def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str):
        super().__init__()
        self.num_folds = num_folds
        self.fit_loop = fit_loop
        self.current_fold: int = 0
        self.export_path = export_path

    @property
    def done(self) -> bool:
        return self.current_fold >= self.num_folds

    def reset(self) -> None:
        """Nothing to reset in this loop."""

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
        model."""
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_folds(self.num_folds)
        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
        print(f"STARTING FOLD {self.current_fold}")
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_fold_index(self.current_fold)

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Used to run fitting and testing on the current fold."""
        self._reset_fitting()  # requires resetting the tracking stage.
        self.fit_loop.run()

        self._reset_testing()  # requires resetting the tracking stage.
        self.trainer.test_loop.run()
        self.current_fold += 1  # increment fold tracking number.

    def on_advance_end(self) -> None:
        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
        # restore the original weights + optimizers and schedulers.
        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
        self.trainer.accelerator.setup_optimizers(self.trainer)

    def on_run_end(self) -> None:
        """Used to compute the performance of the ensemble model on the test set."""
        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
        voting_model.trainer = self.trainer
        # This requires connecting the new model and moving it to the right device.
        self.trainer.accelerator.connect(voting_model)
        self.trainer.training_type_plugin.model_to_device()
        self.trainer.test_loop.run()

    def on_save_checkpoint(self) -> Dict[str, int]:
        return {"current_fold": self.current_fold}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        self.current_fold = state_dict["current_fold"]

    def _reset_fitting(self) -> None:
        self.trainer.reset_train_dataloader()
        self.trainer.reset_val_dataloader()
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True

    def _reset_testing(self) -> None:
        self.trainer.reset_test_dataloader()
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.testing = True

    def __getattr__(self, key) -> Any:
        # needs to be overridden, as attributes of the wrapped loop are accessed.
        if key not in self.__dict__:
            return getattr(self.fit_loop, key)
        return self.__dict__[key]
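

# Note (illustration, not part of the original example): the `__getattr__` override above is
# what lets Lightning internals keep reading attributes of the wrapped loop through the
# `KFoldLoop`, e.g.:
#
#   kfold_loop = KFoldLoop(5, trainer.fit_loop, export_path="./")
#   assert kfold_loop.max_epochs == trainer.fit_loop.max_epochs  # delegated lookup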


#############################################################################################
#                    Step 5 / 5: Connect the KFoldLoop to the Trainer                       #
# After creating the `KFoldDataModule` and our model, the `KFoldLoop` is connected to       #
# the Trainer.                                                                              #
# Finally, use `trainer.fit` to start the cross validation training.                        #
#############################################################################################

model = LitClassifier()
datamodule = MNISTKFoldDataModule()
trainer = Trainer(
    max_epochs=10,
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
    num_sanity_val_steps=0,
    devices=1,
    accelerator="auto",
    strategy="ddp",
)
trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./")
trainer.fit(model, datamodule)
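
# After `trainer.fit(...)` returns, the loop has saved one checkpoint per fold
# (model.1.pt ... model.5.pt under `export_path`) and `on_run_end` has already evaluated
# the `EnsembleVotingModel` built from those checkpoints on the test set.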
@ -1225,7 +1225,8 @@ class Trainer(
         # reload data when needed
         model = self.lightning_module

-        self.reset_train_val_dataloaders(model)
+        if isinstance(self.fit_loop, FitLoop):
+            self.reset_train_val_dataloaders(model)

         self.fit_loop.trainer = self
         with torch.autograd.set_detect_anomaly(self._detect_anomaly):