diff --git a/CHANGELOG.md b/CHANGELOG.md
index 17771756bd..888d22a520 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -186,14 +186,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
+
 - Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
+
 - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
 
 - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
 
+- Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))
+
+
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst
index b83a64d2f6..d779f16c45 100644
--- a/docs/source/extensions/loops.rst
+++ b/docs/source/extensions/loops.rst
@@ -395,7 +395,17 @@ To run the following demo, install Flash and `BaaL `_ first.
 
-Here is the `Active Learning Loop example `_ and the `code for the active learning loop `_.
+Here is the `Active Learning Loop example `_ and the `code for the active learning loop `_.
+
+
+`KFold / Cross Validation `__ is a machine learning practice in which the training dataset is partitioned into `num_folds` complementary subsets.
+Each cross-validation round performs a fit in which one fold is held out for validation and the remaining folds are used for training.
+To reduce variability, once all rounds have been performed using the different folds, the trained models are ensembled and their predictions are
+averaged when estimating the model's predictive performance on the test dataset.
+KFold can be implemented elegantly with `Lightning Loop Customization` as follows:
+
+Here is the `KFold Loop example `_.
+
 
 Advanced Topics and Examples
 ----------------------------
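To make the docs paragraph above concrete, here is a minimal sketch (an editor's illustration, not part of this PR; assumes scikit-learn is installed) of how `KFold` partitions dataset indices into `num_folds` complementary splits. It is the same `KFold(...).split(...)` call the new example file uses:

from sklearn.model_selection import KFold

# 10 samples, 5 folds: each round holds out a different 2-sample validation fold.
for fold_index, (train_indices, val_indices) in enumerate(KFold(n_splits=5).split(range(10))):
    print(fold_index, val_indices)  # fold 0 validates on [0 1], fold 1 on [2 3], ...
# Every sample lands in the validation fold of exactly one round.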
diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py
index 68823eeac7..1d2371c702 100644
--- a/pl_examples/basic_examples/mnist_datamodule.py
+++ b/pl_examples/basic_examples/mnist_datamodule.py
@@ -26,18 +26,21 @@ from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
 
-_TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
-if _TORCHVISION_MNIST_AVAILABLE:
-    try:
-        from torchvision.datasets import MNIST
-
-        MNIST(_DATASETS_PATH, download=True)
-    except HTTPError as e:
-        print(f"Error {e} downloading `torchvision.datasets.MNIST`")
-        _TORCHVISION_MNIST_AVAILABLE = False
-if not _TORCHVISION_MNIST_AVAILABLE:
-    print("`torchvision.datasets.MNIST` not available. Using our hosted version")
-    from tests.helpers.datasets import MNIST
+
+def MNIST(*args, **kwargs):
+    torchvision_mnist_available = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
+    if torchvision_mnist_available:
+        try:
+            from torchvision.datasets import MNIST
+
+            MNIST(_DATASETS_PATH, download=True)
+        except HTTPError as e:
+            print(f"Error {e} downloading `torchvision.datasets.MNIST`")
+            torchvision_mnist_available = False
+    if not torchvision_mnist_available:
+        print("`torchvision.datasets.MNIST` not available. Using our hosted version")
+        from tests.helpers.datasets import MNIST
+    return MNIST(*args, **kwargs)
 
 
 class MNISTDataModule(LightningDataModule):
diff --git a/pl_examples/loop_examples/__init__.py b/pl_examples/loop_examples/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pl_examples/loop_examples/kfold.py b/pl_examples/loop_examples/kfold.py
new file mode 100644
index 0000000000..630b1f26f3
--- /dev/null
+++ b/pl_examples/loop_examples/kfold.py
@@ -0,0 +1,256 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path as osp
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Type
+
+import torch
+import torchvision.transforms as T
+from sklearn.model_selection import KFold
+from torch.nn import functional as F
+from torch.utils.data import random_split
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Dataset, Subset
+
+from pl_examples import _DATASETS_PATH
+from pl_examples.basic_examples.mnist_datamodule import MNIST
+from pl_examples.basic_examples.simple_image_classifier import LitClassifier
+from pytorch_lightning import LightningDataModule, seed_everything, Trainer
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.trainer.states import TrainerFn
+
+#############################################################################################
+#                            KFold Loop / Cross Validation Example                          #
+# This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 #
+# Learn more about the loop structure from the documentation:                               #
+# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html                  #
+#############################################################################################
+
+
+seed_everything(42)
+
+
+#############################################################################################
+#                           Step 1 / 5: Define KFold DataModule API                         #
+# Our KFold DataModule requires implementing the `setup_folds` and `setup_fold_index`       #
+# methods.                                                                                  #
+#############################################################################################
+
+
+class BaseKFoldDataModule(LightningDataModule, ABC):
+    @abstractmethod
+    def setup_folds(self, num_folds: int) -> None:
+        pass
+
+    @abstractmethod
+    def setup_fold_index(self, fold_index: int) -> None:
+        pass
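# Editor's sketch (not part of the PR): the call sequence a k-fold driver is
# expected to issue against this interface. Any concrete subclass works here,
# e.g. the `MNISTKFoldDataModule` defined in Step 2 below.
def _fold_protocol_sketch(dm: BaseKFoldDataModule, num_folds: int) -> None:
    dm.setup_folds(num_folds)  # compute the index splits once, up front
    for fold_index in range(num_folds):
        dm.setup_fold_index(fold_index)  # select this fold's train/val subsets
        train_loader = dm.train_dataloader()  # now serves only the current fold
        val_loader = dm.val_dataloader()  # likewise restricted to the held-out fold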
+#############################################################################################
+#                          Step 2 / 5: Implement the KFoldDataModule                        #
+# The `KFoldDataModule` will take a train and test dataset.                                 #
+# On `setup_folds`, folds will be created depending on the provided argument `num_folds`.   #
+# On `setup_fold_index`, the provided train dataset will be split according to              #
+# the current fold split.                                                                   #
+#############################################################################################
+
+
+@dataclass
+class MNISTKFoldDataModule(BaseKFoldDataModule):
+
+    train_dataset: Optional[Dataset] = None
+    test_dataset: Optional[Dataset] = None
+    train_fold: Optional[Dataset] = None
+    val_fold: Optional[Dataset] = None
+
+    def prepare_data(self) -> None:
+        # download the data.
+        MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        # load the data
+        dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))
+        self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000])
+
+    def setup_folds(self, num_folds: int) -> None:
+        self.num_folds = num_folds
+        self.splits = [split for split in KFold(num_folds).split(range(len(self.train_dataset)))]
+
+    def setup_fold_index(self, fold_index: int) -> None:
+        train_indices, val_indices = self.splits[fold_index]
+        self.train_fold = Subset(self.train_dataset, train_indices)
+        self.val_fold = Subset(self.train_dataset, val_indices)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_fold)
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.val_fold)
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_dataset)
+
+
+#############################################################################################
+#                     Step 3 / 5: Implement the EnsembleVotingModel module                  #
+# The `EnsembleVotingModel` will take our custom LightningModule and                        #
+# several checkpoint_paths.                                                                 #
+#                                                                                           #
+#############################################################################################
+
+
+class EnsembleVotingModel(LightningModule):
+    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]):
+        super().__init__()
+        # Create `num_folds` models with their associated fold weights
+        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
+
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        # Compute the averaged predictions over the `num_folds` models.
+        logits = torch.stack([m(batch[0]) for m in self.models]).mean(0)
+        loss = F.cross_entropy(logits, batch[1])
+        self.log("test_loss", loss)
+
+
+#############################################################################################
+#                             Step 4 / 5: Implement the KFoldLoop                           #
+# Since Lightning v1.5, it is possible to implement your own loop. There are several steps  #
+# to do so, which are described in detail in the documentation:                             #
+# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html.                 #
+# Here, we will implement an outer fit_loop. This means we will subclass the                #
+# base Loop and wrap the current trainer `fit_loop`.                                        #
+#############################################################################################
+
+
+#############################################################################################
+#                         Here is the `Pseudo Code` for the base Loop.                      #
+# class Loop:                                                                               #
+#                                                                                           #
+#   def run(self, ...):                                                                     #
+#       self.reset(...)                                                                     #
+#       self.on_run_start(...)                                                              #
+#                                                                                           #
+#       while not self.done:                                                                #
+#           self.on_advance_start(...)                                                      #
+#           self.advance(...)                                                               #
+#           self.on_advance_end(...)                                                        #
+#                                                                                           #
+#       return self.on_run_end(...)                                                         #
+#############################################################################################
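# Editor's sketch (not part of the PR): the smallest Loop subclass the pseudo
# code above can drive. `done`, `reset` and `advance` are the abstract members
# every loop must provide; `run()` then advances until `done` is True.
from pytorch_lightning.loops.base import Loop


class CountingLoop(Loop):
    def __init__(self, max_iterations: int) -> None:
        super().__init__()
        self.max_iterations = max_iterations
        self.iteration = 0

    @property
    def done(self) -> bool:
        return self.iteration >= self.max_iterations

    def reset(self) -> None:
        self.iteration = 0

    def advance(self, *args: Any, **kwargs: Any) -> None:
        # one pass of the `while not self.done` body in `run()`
        self.iteration += 1


# CountingLoop(3).run() calls reset() once, then advance() three times.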
+class KFoldLoop(Loop):
+    def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str):
+        super().__init__()
+        self.num_folds = num_folds
+        self.fit_loop = fit_loop
+        self.current_fold: int = 0
+        self.export_path = export_path
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def reset(self) -> None:
+        """Nothing to reset in this loop."""
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
+        model."""
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_folds(self.num_folds)
+        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())
+
+    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
+        print(f"STARTING FOLD {self.current_fold}")
+        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
+        self.trainer.datamodule.setup_fold_index(self.current_fold)
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Used to run fitting and testing on the current fold."""
+        self._reset_fitting()  # requires resetting the tracking stage.
+        self.fit_loop.run()
+
+        self._reset_testing()  # requires resetting the tracking stage.
+        self.trainer.test_loop.run()
+        self.current_fold += 1  # increment fold tracking number.
+
+    def on_advance_end(self) -> None:
+        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
+        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
+        # restore the original weights + optimizers and schedulers.
+        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
+        self.trainer.accelerator.setup_optimizers(self.trainer)
+
+    def on_run_end(self) -> None:
+        """Used to compute the performance of the ensemble model on the test set."""
+        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
+        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
+        voting_model.trainer = self.trainer
+        # This requires connecting the new model and moving it to the right device.
+        self.trainer.accelerator.connect(voting_model)
+        self.trainer.training_type_plugin.model_to_device()
+        self.trainer.test_loop.run()
+
+    def on_save_checkpoint(self) -> Dict[str, int]:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+    def _reset_fitting(self) -> None:
+        self.trainer.reset_train_dataloader()
+        self.trainer.reset_val_dataloader()
+        self.trainer.state.fn = TrainerFn.FITTING
+        self.trainer.training = True
+
+    def _reset_testing(self) -> None:
+        self.trainer.reset_test_dataloader()
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.testing = True
+
+    def __getattr__(self, key) -> Any:
+        # needs to be overridden, since attributes of the wrapped loop are accessed through this wrapper.
+        if key not in self.__dict__:
+            return getattr(self.fit_loop, key)
+        return self.__dict__[key]
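# Editor's note (not part of the PR): `advance` increments `current_fold` before
# `on_advance_end` saves the checkpoint, so fold 0's weights are exported to
# "model.1.pt", fold 1's to "model.2.pt", and so on. This matches the
# `f_idx + 1` naming that `on_run_end` reads back:
assert [f"model.{f_idx + 1}.pt" for f_idx in range(2)] == ["model.1.pt", "model.2.pt"]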
+#############################################################################################
+#                        Step 5 / 5: Connect the KFoldLoop to the Trainer                   #
+# After creating the `KFoldDataModule` and our model, the `KFoldLoop` is connected to       #
+# the Trainer.                                                                              #
+# Finally, use `trainer.fit` to start the cross validation training.                        #
+#############################################################################################
+
+model = LitClassifier()
+datamodule = MNISTKFoldDataModule()
+trainer = Trainer(
+    max_epochs=10,
+    limit_train_batches=2,
+    limit_val_batches=2,
+    limit_test_batches=2,
+    num_sanity_val_steps=0,
+    devices=1,
+    accelerator="auto",
+    strategy="ddp",
+)
+trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./")
+trainer.fit(model, datamodule)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 2990792502..ac20014d49 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1225,7 +1225,8 @@ class Trainer(
         # reload data when needed
         model = self.lightning_module
 
-        self.reset_train_val_dataloaders(model)
+        if isinstance(self.fit_loop, FitLoop):
+            self.reset_train_val_dataloaders(model)
 
         self.fit_loop.trainer = self
         with torch.autograd.set_detect_anomaly(self._detect_anomaly):
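This guard is needed because a custom outer loop such as the `KFoldLoop` above subclasses `Loop`, not `FitLoop`, and resets its own dataloaders per fold (see `_reset_fitting`). A short sketch (editor's illustration, referring to the example above) of what the new check evaluates once the loop has been swapped in:

from pytorch_lightning.loops.fit_loop import FitLoop

# after `trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./")`
assert not isinstance(trainer.fit_loop, FitLoop)
# so the Trainer skips the automatic `reset_train_val_dataloaders` call and the
# loop calls `reset_train_dataloader` / `reset_val_dataloader` itself each fold.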