{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02-datamodules.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2O5r7QvP8-rt",
        "colab_type": "text"
      },
      "source": [
        "# PyTorch Lightning DataModules ⚡\n",
        "\n",
        "With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data-related hooks from your `LightningModule`.\n",
        "\n",
        "This notebook will walk you through how to start using DataModules.\n",
        "\n",
        "The most up-to-date documentation on DataModules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html).\n",
        "\n",
        "---\n",
        "\n",
        "- Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
        "- Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
        "- Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6RYMhmfA9ATN",
        "colab_type": "text"
      },
      "source": [
        "### Setup\n",
        "Lightning is easy to install. Simply ```pip install pytorch-lightning```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lj2zD-wsbvGr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "! pip install pytorch-lightning --quiet"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8g2mbvy-9xDI",
        "colab_type": "text"
      },
      "source": [
        "# Introduction\n",
        "\n",
        "First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eg-xDlmDdAwy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import pytorch_lightning as pl\n",
        "from pytorch_lightning.metrics.functional import accuracy\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import random_split, DataLoader\n",
        "\n",
        "# Note - you must have torchvision installed for this example\n",
        "from torchvision.datasets import MNIST, CIFAR10\n",
        "from torchvision import transforms"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DzgY7wi88UuG",
        "colab_type": "text"
      },
      "source": [
        "## Defining the LitMNIST Model\n",
        "\n",
        "Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST handwritten digits.\n",
        "\n",
        "Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST data. 😢\n",
        "\n",
        "This is fine if you don't plan on training/evaluating your model on other datasets. However, it becomes bothersome as soon as you want to try out your architecture with different data."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IQkW8_FF5nU2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class LitMNIST(pl.LightningModule):\n",
        "\n",
        "    def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
        "\n",
        "        super().__init__()\n",
        "\n",
        "        # We hardcode dataset specific stuff here.\n",
        "        self.data_dir = data_dir\n",
        "        self.num_classes = 10\n",
        "        self.dims = (1, 28, 28)\n",
        "        channels, width, height = self.dims\n",
        "        self.transform = transforms.Compose([\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0.1307,), (0.3081,))\n",
        "        ])\n",
        "\n",
        "        self.hidden_size = hidden_size\n",
        "        self.learning_rate = learning_rate\n",
        "\n",
        "        # Build model\n",
        "        self.model = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(channels * width * height, hidden_size),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(hidden_size, hidden_size),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(hidden_size, self.num_classes)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.model(x)\n",
        "        return F.log_softmax(x, dim=1)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        return pl.TrainResult(loss)\n",
        "\n",
        "    def validation_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        preds = torch.argmax(logits, dim=1)\n",
        "        acc = accuracy(preds, y)\n",
        "        result = pl.EvalResult(checkpoint_on=loss)\n",
        "        result.log('val_loss', loss, prog_bar=True)\n",
        "        result.log('val_acc', acc, prog_bar=True)\n",
        "        return result\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
        "        return optimizer\n",
        "\n",
        "    ####################\n",
        "    # DATA RELATED HOOKS\n",
        "    ####################\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # download\n",
        "        MNIST(self.data_dir, train=True, download=True)\n",
        "        MNIST(self.data_dir, train=False, download=True)\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "\n",
        "        # Assign train/val datasets for use in dataloaders\n",
        "        if stage == 'fit' or stage is None:\n",
        "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
        "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
        "\n",
        "        # Assign test dataset for use in dataloader(s)\n",
        "        if stage == 'test' or stage is None:\n",
        "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(self.mnist_train, batch_size=32)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(self.mnist_val, batch_size=32)\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        return DataLoader(self.mnist_test, batch_size=32)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K7sg9KQd-QIO",
        "colab_type": "text"
      },
      "source": [
        "## Training the LitMNIST Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QxDNDaus6byD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = LitMNIST()\n",
        "trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)\n",
        "trainer.fit(model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dY8d6GxmB0YU",
        "colab_type": "text"
      },
      "source": [
        "# Using DataModules\n",
        "\n",
        "DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset-agnostic models."
      ]
    },
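    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "At its core, a `LightningDataModule` is just a named collection of the data hooks you saw above. Before we fill one in for real, here is a bare-bones sketch of the interface (`MyDataModule` is a placeholder name for this sketch, not a library class):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "class MyDataModule(pl.LightningDataModule):\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # Download / write to disk. Called once, from a single process.\n",
        "        pass\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "        # Split, transform, and assign self.train/val/test datasets.\n",
        "        # Called on every GPU/process.\n",
        "        pass\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        # return DataLoader(self.train_dataset, batch_size=32)\n",
        "        pass\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        # return DataLoader(self.val_dataset, batch_size=32)\n",
        "        pass\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        # return DataLoader(self.test_dataset, batch_size=32)\n",
        "        pass"
      ],
      "execution_count": null,
      "outputs": []
    },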
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eJeT5bW081wn",
        "colab_type": "text"
      },
      "source": [
        "## Defining the MNISTDataModule\n",
        "\n",
        "Let's go over each function in the class below and talk about what they're doing:\n",
        "\n",
        "1. ```__init__```\n",
        "    - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.\n",
        "    - Defines a transform that will be applied across the train, val, and test dataset splits.\n",
        "    - Defines default `self.dims`, which is a tuple returned from `datamodule.size()` that can help you initialize models.\n",
        "\n",
        "\n",
        "2. ```prepare_data```\n",
        "    - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download it if it isn't found there.\n",
        "    - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
        "\n",
        "3. ```setup```\n",
        "    - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).\n",
        "    - `setup` expects a `stage` arg which is used to separate logic for 'fit' and 'test'.\n",
        "    - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit'-related setup and 'test'-related setup to run whenever `None` is passed to `stage`.\n",
        "    - **Note this runs across all GPUs and it *is* safe to make state assignments here.**\n",
        "\n",
        "\n",
        "4. ```x_dataloader```\n",
        "    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances, created by wrapping the respective datasets that we prepared in `setup()`. A quick standalone check of these hooks follows the implementation below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DfGKyGwG_X9v",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class MNISTDataModule(pl.LightningDataModule):\n",
        "\n",
        "    def __init__(self, data_dir: str = './'):\n",
        "        super().__init__()\n",
        "        self.data_dir = data_dir\n",
        "        self.transform = transforms.Compose([\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0.1307,), (0.3081,))\n",
        "        ])\n",
        "\n",
        "        # self.dims is returned when you call dm.size()\n",
        "        # Setting default dims here because we know them.\n",
        "        # Could optionally be assigned dynamically in dm.setup()\n",
        "        self.dims = (1, 28, 28)\n",
        "        self.num_classes = 10\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # download\n",
        "        MNIST(self.data_dir, train=True, download=True)\n",
        "        MNIST(self.data_dir, train=False, download=True)\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "\n",
        "        # Assign train/val datasets for use in dataloaders\n",
        "        if stage == 'fit' or stage is None:\n",
        "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
        "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
        "\n",
        "        # Assign test dataset for use in dataloader(s)\n",
        "        if stage == 'test' or stage is None:\n",
        "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(self.mnist_train, batch_size=32)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(self.mnist_val, batch_size=32)\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        return DataLoader(self.mnist_test, batch_size=32)"
      ],
      "execution_count": null,
      "outputs": []
    },
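    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the datamodule is self-contained, we can sanity-check it outside of a `Trainer` by calling its hooks manually (the `Trainer` normally calls `prepare_data` and `setup` for us). A minimal sketch:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "dm = MNISTDataModule()\n",
        "print(dm.size())  # (1, 28, 28) - the default dims we set in __init__\n",
        "\n",
        "# The Trainer normally calls these hooks for us\n",
        "dm.prepare_data()\n",
        "dm.setup('fit')\n",
        "\n",
        "# Grab one batch from the train dataloader to verify shapes\n",
        "x, y = next(iter(dm.train_dataloader()))\n",
        "print(x.shape)  # expected: torch.Size([32, 1, 28, 28])\n",
        "print(y.shape)  # expected: torch.Size([32])"
      ],
      "execution_count": null,
      "outputs": []
    },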
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H2Yoj-9M9dS7",
        "colab_type": "text"
      },
      "source": [
        "## Defining the dataset-agnostic `LitModel`\n",
        "\n",
        "Below, we define the same model as the `LitMNIST` model we made earlier.\n",
        "\n",
        "However, this time our model has the freedom to use any input data that we'd like 🔥."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PM2IISuOBDIu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class LitModel(pl.LightningModule):\n",
        "\n",
        "    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n",
        "\n",
        "        super().__init__()\n",
        "\n",
        "        # We take in input dimensions as parameters and use those to dynamically build model.\n",
        "        self.channels = channels\n",
        "        self.width = width\n",
        "        self.height = height\n",
        "        self.num_classes = num_classes\n",
        "        self.hidden_size = hidden_size\n",
        "        self.learning_rate = learning_rate\n",
        "\n",
        "        self.model = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(channels * width * height, hidden_size),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(hidden_size, hidden_size),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(hidden_size, num_classes)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.model(x)\n",
        "        return F.log_softmax(x, dim=1)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        return pl.TrainResult(loss)\n",
        "\n",
        "    def validation_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        preds = torch.argmax(logits, dim=1)\n",
        "        acc = accuracy(preds, y)\n",
        "        result = pl.EvalResult(checkpoint_on=loss)\n",
        "        result.log('val_loss', loss, prog_bar=True)\n",
        "        result.log('val_acc', acc, prog_bar=True)\n",
        "        return result\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
        "        return optimizer"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G4Z5olPe-xEo",
        "colab_type": "text"
      },
      "source": [
        "## Training the `LitModel` using the `MNISTDataModule`\n",
        "\n",
        "Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kV48vP_9mEli",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Init DataModule\n",
        "dm = MNISTDataModule()\n",
        "# Init model from datamodule's attributes\n",
        "model = LitModel(*dm.size(), dm.num_classes)\n",
        "# Init trainer\n",
        "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1)\n",
        "# Pass the datamodule as arg to trainer.fit to override model hooks :)\n",
        "trainer.fit(model, dm)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WNxrugIGRRv5",
        "colab_type": "text"
      },
      "source": [
        "## Defining the CIFAR10 DataModule\n",
        "\n",
        "Let's prove the `LitModel` we made earlier is dataset-agnostic by defining a new datamodule for the CIFAR10 dataset."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1tkaYLU7RT5P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class CIFAR10DataModule(pl.LightningDataModule):\n",
        "\n",
        "    def __init__(self, data_dir: str = './'):\n",
        "        super().__init__()\n",
        "        self.data_dir = data_dir\n",
        "        self.transform = transforms.Compose([\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
        "        ])\n",
        "\n",
        "        self.dims = (3, 32, 32)\n",
        "        self.num_classes = 10\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # download\n",
        "        CIFAR10(self.data_dir, train=True, download=True)\n",
        "        CIFAR10(self.data_dir, train=False, download=True)\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "\n",
        "        # Assign train/val datasets for use in dataloaders\n",
        "        if stage == 'fit' or stage is None:\n",
        "            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n",
        "            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])\n",
        "\n",
        "        # Assign test dataset for use in dataloader(s)\n",
        "        if stage == 'test' or stage is None:\n",
        "            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(self.cifar_train, batch_size=32)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(self.cifar_val, batch_size=32)\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        return DataLoader(self.cifar_test, batch_size=32)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BrXxf3oX_gsZ",
        "colab_type": "text"
      },
      "source": [
        "## Training the `LitModel` using the `CIFAR10DataModule`\n",
        "\n",
        "Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.\n",
        "\n",
        "The point here is that our `LitModel` has no problem using a different datamodule as its input data."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sd-SbWi_krdj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "dm = CIFAR10DataModule()\n",
        "model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)\n",
        "trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1)\n",
        "trainer.fit(model, dm)"
      ],
      "execution_count": null,
      "outputs": []
    },
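    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Evaluating on the test split\n",
        "\n",
        "Both datamodules also carry a test split, but `LitModel` defines no `test_step`, so the `Trainer` has nothing to run for it. Below is a hedged sketch of what evaluation could look like: we subclass `LitModel` to add a `test_step` that mirrors the validation logic, and assume `trainer.test` accepts a `datamodule` argument the same way `trainer.fit` does."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# A sketch, not part of the tutorial proper: add a test_step by subclassing.\n",
        "class LitModelWithTest(LitModel):\n",
        "\n",
        "    def test_step(self, batch, batch_idx):\n",
        "        # Reuse the validation logic for the test split\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        acc = accuracy(torch.argmax(logits, dim=1), y)\n",
        "        result = pl.EvalResult()\n",
        "        result.log('test_loss', loss)\n",
        "        result.log('test_acc', acc)\n",
        "        return result\n",
        "\n",
        "dm = CIFAR10DataModule()\n",
        "model = LitModelWithTest(*dm.size(), dm.num_classes, hidden_size=256)\n",
        "trainer = pl.Trainer(max_epochs=1, progress_bar_refresh_rate=20, gpus=1)\n",
        "trainer.fit(model, dm)\n",
        "# Assumption: trainer.test(datamodule=...) is supported, mirroring trainer.fit\n",
        "trainer.test(datamodule=dm)"
      ],
      "execution_count": null,
      "outputs": []
    }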
  ]
}