K-Fold Cross Validation

This example shows how to perform K-Fold cross-validation with Lightning Fabric. To learn more about cross-validation, check out this article.

We use the MNIST dataset to train a simple CNN. The K-Fold cross-validation splits are created with the `KFold` class from scikit-learn's `model_selection` module (`sklearn.model_selection.KFold`), so make sure scikit-learn is installed:

pip install scikit-learn
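To make the fold layout concrete, here is a minimal, dependency-free sketch of the index splits that `KFold(n_splits=k)` produces with its default `shuffle=False`. The validation folds are contiguous, the first `n_samples % k` folds get one extra sample, and all remaining indices form the training set. The helpers `kfold_indices` and `cross_validate` and the `train_and_score` callback are hypothetical names used only for illustration. With scikit-learn you get the same indices from `KFold(n_splits=k).split(X)`, which yields `(train_idx, val_idx)` pairs, and you would train one CNN per fold on those indices.

```python
def kfold_indices(n_samples: int, n_splits: int):
    """Yield (train_idx, val_idx) lists for contiguous, unshuffled K-Fold splits.

    Hypothetical helper: mimics the fold layout of
    sklearn.model_selection.KFold(n_splits, shuffle=False), where the first
    n_samples % n_splits folds receive one extra sample.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # Earlier folds absorb the remainder, one sample each.
        size = base + (1 if fold < extra else 0)
        stop = start + size
        val_idx = list(range(start, stop))
        # Everything outside the validation fold is used for training.
        train_idx = list(range(0, start)) + list(range(stop, n_samples))
        yield train_idx, val_idx
        start = stop


def cross_validate(n_samples, n_splits, train_and_score):
    """Call the (hypothetical) train_and_score callback once per fold.

    Returns the list of per-fold scores and their mean, which is the
    cross-validation estimate.
    """
    scores = [
        train_and_score(train_idx, val_idx)
        for train_idx, val_idx in kfold_indices(n_samples, n_splits)
    ]
    return scores, sum(scores) / len(scores)


if __name__ == "__main__":
    for fold, (train_idx, val_idx) in enumerate(kfold_indices(10, 3)):
        print(fold, val_idx)
    # 0 [0, 1, 2, 3]
    # 1 [4, 5, 6]
    # 2 [7, 8, 9]
```

Because `shuffle=False` makes the splits deterministic, repeated runs see the same folds, and every sample is used for validation exactly once across the `k` folds.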

Run K-Fold Image Classification with Lightning Fabric

The `train_fabric.py` script shows how to scale plain PyTorch code to GPU and multi-GPU training with Lightning Fabric.

# CPU
fabric run train_fabric.py

# GPU (CUDA or M1 Mac)
fabric run train_fabric.py --accelerator=gpu

# Multiple GPUs
fabric run train_fabric.py --accelerator=gpu --devices=4

References