"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2O5r7QvP8-rt",
"colab_type": "text"
},
"source": [
"# PyTorch Lightning DataModules ⚡\n",
"\n",
"With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data related hooks from your `LightningModule`.\n",
"\n",
"This notebook will walk you through how to start using Datamodules.\n",
"\n",
"The most up to date documentation on datamodules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html).\n",
"\n",
"---\n",
"\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6RYMhmfA9ATN",
"colab_type": "text"
},
"source": [
"### Setup\n",
"Lightning is easy to install. Simply ```pip install pytorch-lightning```"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lj2zD-wsbvGr",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install pytorch-lightning --quiet"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "8g2mbvy-9xDI",
"colab_type": "text"
},
"source": [
"# Introduction\n",
"\n",
"First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`"
"Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST Handwritten Digits.\n",
"\n",
"Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢\n",
"\n",
"This is fine if you don't plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets."
"DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset agnostic models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eJeT5bW081wn",
"colab_type": "text"
},
"source": [
"## Defining The MNISTDataModule\n",
"\n",
"Let's go over each function in the class below and talk about what they're doing:\n",
"\n",
"1. ```__init__```\n",
" - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.\n",
" - Defines a transform that will be applied across train, val, and test dataset splits.\n",
" - Defines default `self.dims`, which is a tuple returned from `datamodule.size()` that can help you initialize models.\n",
"\n",
"\n",
"2. ```prepare_data```\n",
" - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
" - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
"\n",
"3. ```setup```\n",
" - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
" - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
" - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage`.\n",
" - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
"\n",
"\n",
"4. ```x_dataloader```\n",
" - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"