"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i7XbLCXGkll9",
"colab_type": "text"
},
"source": [
"# Introduction to Pytorch Lightning ⚡\n",
"\n",
"In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n",
"\n",
"---\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2LODD6w9ixlT",
"colab_type": "text"
},
"source": [
"### Setup \n",
"Lightning is easy to install. Simply ```pip install pytorch-lightning```"
"## A more complete MNIST Lightning Module Example\n",
"\n",
"That wasn't so hard was it?\n",
"\n",
"Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n",
"\n",
"This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n",
"\n",
"---\n",
"\n",
"### Note what the following built-in functions are doing:\n",
" - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
" - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
" - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
" - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
" - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n",
" - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
" - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
"Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "PA151FkLtprO",
"colab_type": "code",
"colab": {}
},
"source": [
"trainer.test()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "T3-3lbbNtr5T",
"colab_type": "text"
},
"source": [
"### Bonus Tip\n",
"\n",
"You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IFBwCbLet2r6",
"colab_type": "code",
"colab": {}
},
"source": [
"trainer.fit(model)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "8TRyS5CCt3n9",
"colab_type": "text"
},
"source": [
"In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"