K-Fold Cross Validation
This example shows how to perform K-Fold cross validation with Lightning Fabric. To learn more about cross validation, check out this article.
We use the MNIST dataset to train a simple CNN model. We create the k-fold cross validation splits with the `model_selection.KFold` class from the scikit-learn library. Make sure scikit-learn is installed:
```bash
pip install scikit-learn
```
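The plain-Python sketch below mirrors how scikit-learn's `KFold` with its default `shuffle=False` carves sample indices into contiguous folds. It then shows one common way to use those folds: fit a fresh model on each training split, score it on the held-out fold, and average the fold scores. `kfold_indices`, `cross_validate`, `fit_mean`, and `mae` are hypothetical helpers for illustration, not code from `train_fabric.py`. In the real script the index lists would presumably select MNIST samples (for example through `torch.utils.data.Subset`) and the model would be the CNN rather than a toy mean predictor.

```python
from statistics import mean
from typing import Callable, Iterator, List, Sequence, Tuple

# Hypothetical helpers for illustration; not taken from train_fabric.py.


def kfold_indices(n_samples: int, n_splits: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, val_idx) pairs, mirroring KFold(n_splits, shuffle=False)."""
    if not 2 <= n_splits <= n_samples:
        raise ValueError("need 2 <= n_splits <= n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # The first `extra` folds are one sample larger than the rest.
        stop = start + base + (1 if fold < extra else 0)
        val_idx = list(range(start, stop))
        train_idx = list(range(0, start)) + list(range(stop, n_samples))
        yield train_idx, val_idx
        start = stop


def cross_validate(
    targets: Sequence[float],
    n_splits: int,
    fit: Callable[[List[float]], float],
    score: Callable[[float, List[float]], float],
) -> float:
    """Fit a fresh model on each training split, score it on the held-out fold, average."""
    fold_scores = []
    for train_idx, val_idx in kfold_indices(len(targets), n_splits):
        model = fit([targets[i] for i in train_idx])  # a new model for every fold
        fold_scores.append(score(model, [targets[i] for i in val_idx]))
    return mean(fold_scores)


# Toy "model": predict the mean of the training targets; score with mean absolute error.
def fit_mean(ys: List[float]) -> float:
    return sum(ys) / len(ys)


def mae(prediction: float, ys: List[float]) -> float:
    return sum(abs(y - prediction) for y in ys) / len(ys)


for train_idx, val_idx in kfold_indices(10, 3):
    print(val_idx)  # prints [0, 1, 2, 3], then [4, 5, 6], then [7, 8, 9]
print(cross_validate([0, 0, 0, 10, 10, 10], 3, fit_mean, mae))
```

Every sample lands in exactly one validation fold, so the averaged score uses all of the data for evaluation while never scoring a model on samples it was trained on.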
Run K-Fold Image Classification with Lightning Fabric
This script shows how to scale pure PyTorch code to GPU and multi-GPU training with Lightning Fabric.
```bash
# CPU
fabric run train_fabric.py

# GPU (CUDA or M1 Mac)
fabric run train_fabric.py --accelerator=gpu

# Multiple GPUs
fabric run train_fabric.py --accelerator=gpu --devices=4
```
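With `--devices=4`, four processes train in parallel and each one should see a different slice of the data. Assuming the script passes its dataloaders through `fabric.setup_dataloaders` with the default `use_distributed_sampler=True`, Fabric installs (or wraps the existing sampler in) a distributed sampler. The sketch below approximates how PyTorch's `DistributedSampler` with `shuffle=False` and `drop_last=False` assigns indices to each rank; `shard_indices` is a hypothetical helper, not part of Lightning or `train_fabric.py`.

```python
import math
from typing import List


def shard_indices(n_samples: int, world_size: int, rank: int) -> List[int]:
    """Indices seen by one process, mimicking DistributedSampler(shuffle=False, drop_last=False).

    Hypothetical helper for illustration; not part of Lightning or train_fabric.py.
    """
    if n_samples <= 0 or world_size <= 0 or not 0 <= rank < world_size:
        raise ValueError("invalid n_samples, world_size, or rank")
    num_samples = math.ceil(n_samples / world_size)  # samples per process
    total_size = num_samples * world_size
    # Pad by wrapping around to the start so every process gets the same count.
    padded = [i % n_samples for i in range(total_size)]
    # Each rank takes every world_size-th index, starting at its own rank.
    return padded[rank:total_size:world_size]


# Ten samples split across four devices (as with --devices=4).
for rank in range(4):
    print(rank, shard_indices(10, 4, rank))
```

Here the ranks receive `[0, 4, 8]`, `[1, 5, 9]`, `[2, 6, 0]`, and `[3, 7, 1]`: because 10 is not divisible by 4, samples 0 and 1 are repeated so that every process runs the same number of steps per epoch.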