From 6e5f232f5cec2b5e635ae34fa365c6b969d0902e Mon Sep 17 00:00:00 2001
From: Indrayana Rustandi <468296+irustandi@users.noreply.github.com>
Date: Fri, 6 Nov 2020 09:53:46 -0500
Subject: [PATCH] Add DALI MNIST example (#3721)

* add MNIST DALI example, update README.md
* Fix PEP8 warnings
* reformatted using black
* add mnist_dali to test_examples.py
* Add documentation as docstrings
* add nvidia-pyindex and nvidia-dali-cuda100
* replace nvidia-pyindex with --extra-index-url
* mark mnist_dali test as Linux and GPU only
* adjust CUDA docker and examples.txt, fix import error in test_examples.py
* adjust the GPU check
* Exit when DALI is not available
* remove requirements-examples.txt and DALI pip install
* Refactored example, moved to new logging api, added runtime check for test and dali script
* Patch to reflect the mnist example module
* add req.
* Apply suggestions from code review
* Removed requirement as it breaks CPU install, added note in README to install DALI
* add DALI to Drone
* test examples
* Apply suggestions from code review
* imports
* ABC
* cuda
* cuda
* pip DALI
* Move build into init function

Co-authored-by: SeanNaren
Co-authored-by: Jirka Borovec
Co-authored-by: Jirka Borovec
Co-authored-by: Sean Naren
---
 .drone.yml | 2 +
 pl_examples/basic_examples/README.md | 10 +-
 pl_examples/basic_examples/mnist_dali.py | 204 +++++++++++++++++++++++
 pl_examples/test_examples.py | 41 ++++-
 requirements/examples.txt | 2 +-
 5 files changed, 249 insertions(+), 10 deletions(-)
 create mode 100644 pl_examples/basic_examples/mnist_dali.py

diff --git a/.drone.yml b/.drone.yml
index 5e6c08f7a8..9774ffaaae 100644
--- a/.drone.yml
+++ b/.drone.yml
@@ -32,6 +32,8 @@ steps:
 - pip --version
 - nvidia-smi
 - pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir
+ # once the image defines a CUDA version we can switch to the package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0"
+ - pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed
 - pip list
 - coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --color=yes --durations=25 # --flake8
 - python -m pytest benchmarks pl_examples -v --color=yes --maxfail=2 --durations=0 # --flake8
diff --git a/pl_examples/basic_examples/README.md b/pl_examples/basic_examples/README.md
index 4dcf06a74b..18ae204396 100644
--- a/pl_examples/basic_examples/README.md
+++ b/pl_examples/basic_examples/README.md
@@ -14,7 +14,15 @@ python mnist.py
 python mnist.py --gpus 2 --distributed_backend 'dp'
 ```
 
----
+---
+#### MNIST with DALI
+The MNIST example above, with data loading handled by [NVIDIA DALI](https://developer.nvidia.com/DALI).
+Requires NVIDIA DALI to be installed for your CUDA version; see [here](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html).
+```bash
+python mnist_dali.py
+```
+
+---
 #### Image classifier
 Generic image classifier with an arbitrary backbone (ie: a simple system)
 ```bash
diff --git a/pl_examples/basic_examples/mnist_dali.py b/pl_examples/basic_examples/mnist_dali.py
new file mode 100644
index 0000000000..649198053a
--- /dev/null
+++ b/pl_examples/basic_examples/mnist_dali.py
@@ -0,0 +1,204 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC +from argparse import ArgumentParser +from random import shuffle +from warnings import warn + +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data import random_split + +import pytorch_lightning as pl + +try: + from torchvision.datasets.mnist import MNIST + from torchvision import transforms +except Exception: + from tests.base.datasets import MNIST + +try: + import nvidia.dali.ops as ops + import nvidia.dali.types as types + from nvidia.dali.pipeline import Pipeline + from nvidia.dali.plugin.pytorch import DALIClassificationIterator +except (ImportError, ModuleNotFoundError): + warn('NVIDIA DALI is not available') + ops, types, Pipeline, DALIClassificationIterator = ..., ..., ABC, ABC + + +class ExternalMNISTInputIterator(object): + """ + This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches + """ + + def __init__(self, mnist_ds, batch_size): + self.batch_size = batch_size + self.mnist_ds = mnist_ds + self.indices = list(range(len(self.mnist_ds))) + shuffle(self.indices) + + def __iter__(self): + self.i = 0 + self.n = len(self.mnist_ds) + return self + + def __next__(self): + batch = [] + labels = [] + for _ in range(self.batch_size): + index = self.indices[self.i] + img, label = self.mnist_ds[index] + batch.append(img.numpy()) + labels.append(np.array([label], dtype=np.uint8)) + self.i = (self.i + 1) % self.n + return (batch, labels) + + +class ExternalSourcePipeline(Pipeline): + """ + This DALI pipeline class just contains the MNIST iterator + """ + + def __init__(self, batch_size, eii, num_threads, device_id): + super(ExternalSourcePipeline, self).__init__(batch_size, num_threads, device_id, seed=12) + self.source = ops.ExternalSource(source=eii, num_outputs=2) + self.build() + + def define_graph(self): + images, labels = self.source() + return images, labels + + +class DALIClassificationLoader(DALIClassificationIterator): + """ + This class extends DALI's original DALIClassificationIterator with the __len__() function so that we can call len() on it + """ + + def __init__( + self, + pipelines, + size=-1, + reader_name=None, + auto_reset=False, + fill_last_batch=True, + dynamic_shape=False, + last_batch_padded=False, + ): + super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded) + + def __len__(self): + batch_count = self._size // (self._num_gpus * self.batch_size) + last_batch = 1 if self._fill_last_batch else 0 + return batch_count + last_batch + + +class LitClassifier(pl.LightningModule): + def __init__(self, hidden_dim=128, learning_rate=1e-3): + super().__init__() + self.save_hyperparameters() + + self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) + self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = torch.relu(self.l1(x)) + x = torch.relu(self.l2(x)) + return x + + def split_batch(self, batch): + return batch[0]["data"], batch[0]["label"].squeeze().long() + + def training_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat 
= self(x) + loss = F.cross_entropy(y_hat, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('valid_loss', loss) + + def test_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('test_loss', loss) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--hidden_dim', type=int, default=128) + parser.add_argument('--learning_rate', type=float, default=0.0001) + return parser + + +def cli_main(): + pl.seed_everything(1234) + + # ------------ + # args + # ------------ + parser = ArgumentParser() + parser.add_argument('--batch_size', default=32, type=int) + parser = pl.Trainer.add_argparse_args(parser) + parser = LitClassifier.add_model_specific_args(parser) + args = parser.parse_args() + + # ------------ + # data + # ------------ + dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor()) + mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor()) + mnist_train, mnist_val = random_split(dataset, [55000, 5000]) + + eii_train = ExternalMNISTInputIterator(mnist_train, args.batch_size) + eii_val = ExternalMNISTInputIterator(mnist_val, args.batch_size) + eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size) + + pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0) + train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False) + + pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0) + val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False) + + pipe_test = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_test, num_threads=2, device_id=0) + test_loader = DALIClassificationLoader(pipe_test, size=len(mnist_test), auto_reset=True, fill_last_batch=False) + + # ------------ + # model + # ------------ + model = LitClassifier(args.hidden_dim, args.learning_rate) + + # ------------ + # training + # ------------ + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model, train_loader, val_loader) + + # ------------ + # testing + # ------------ + trainer.test(test_dataloaders=test_loader) + + +if __name__ == "__main__": + cli_main() diff --git a/pl_examples/test_examples.py b/pl_examples/test_examples.py index 7fe5d4ed60..60f10a637e 100644 --- a/pl_examples/test_examples.py +++ b/pl_examples/test_examples.py @@ -1,6 +1,15 @@ +import platform from unittest import mock -import torch + import pytest +import torch + +try: + from nvidia.dali import ops, types, pipeline, plugin +except (ImportError, ModuleNotFoundError): + DALI_AVAILABLE = False +else: + DALI_AVAILABLE = True dp_16_args = """ --max_epochs 1 \ @@ -28,7 +37,7 @@ ddp_args = """ --precision 16 \ """ - +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [dp_16_args]) # def test_examples_dp_mnist(cli_args): @@ -38,6 +47,7 @@ ddp_args = """ # cli_main() +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', 
[dp_16_args]) # def test_examples_dp_image_classifier(cli_args): @@ -45,8 +55,9 @@ ddp_args = """ # # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): # cli_main() -# -# + + +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [dp_16_args]) # def test_examples_dp_autoencoder(cli_args): @@ -56,6 +67,7 @@ ddp_args = """ # cli_main() +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [ddp_args]) # def test_examples_ddp_mnist(cli_args): @@ -63,8 +75,9 @@ ddp_args = """ # # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): # cli_main() -# -# + + +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [ddp_args]) # def test_examples_ddp_image_classifier(cli_args): @@ -72,8 +85,9 @@ ddp_args = """ # # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): # cli_main() -# -# + + +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [ddp_args]) # def test_examples_ddp_autoencoder(cli_args): @@ -92,3 +106,14 @@ def test_examples_cpu(cli_args): for cli_cmd in [mnist_cli, ic_cli, ae_cli]: with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): cli_cmd() + + +@pytest.mark.skipif(not DALI_AVAILABLE, reason="Nvidia DALI required") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(platform.system() != 'Linux', reason='Only applies to Linux platform.') +@pytest.mark.parametrize('cli_args', [cpu_args]) +def test_examples_mnist_dali(cli_args): + from pl_examples.basic_examples.mnist_dali import cli_main + + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): + cli_main() diff --git a/requirements/examples.txt b/requirements/examples.txt index e930579b8b..0afa62f9ff 100644 --- a/requirements/examples.txt +++ b/requirements/examples.txt @@ -1,2 +1,2 @@ torchvision>=0.4.1,<0.9.0 -gym>=0.17.0 \ No newline at end of file +gym>=0.17.0
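
Reviewer note (not part of the patch): a minimal sketch of how the classes introduced in `mnist_dali.py` fit together outside the Lightning `Trainer`. It assumes `torchvision` and `nvidia-dali-cuda100` are installed and GPU 0 is visible; `batch_size=32` and `num_threads=2` simply mirror the example's defaults.

```python
# Hypothetical usage sketch of the DALI pieces from
# pl_examples/basic_examples/mnist_dali.py, run without the Lightning Trainer.
from torchvision import transforms
from torchvision.datasets.mnist import MNIST

from pl_examples.basic_examples.mnist_dali import (
    DALIClassificationLoader,
    ExternalMNISTInputIterator,
    ExternalSourcePipeline,
)

batch_size = 32  # mirrors the example's --batch_size default

# Plain torchvision dataset; ToTensor() so the iterator can hand arrays to DALI
mnist_train = MNIST('', train=True, download=True, transform=transforms.ToTensor())

# Wrap the dataset so DALI's ExternalSource can pull (images, labels) batches from it
eii = ExternalMNISTInputIterator(mnist_train, batch_size)
pipe = ExternalSourcePipeline(batch_size=batch_size, eii=eii, num_threads=2, device_id=0)
loader = DALIClassificationLoader(pipe, size=len(mnist_train), auto_reset=True, fill_last_batch=False)

# Each element is a list with one dict per GPU, keyed by "data" and "label",
# which is exactly what LitClassifier.split_batch() unpacks.
for batch in loader:
    x, y = batch[0]["data"], batch[0]["label"].squeeze().long()
    print(x.shape, y.shape)  # e.g. torch.Size([32, 1, 28, 28]), torch.Size([32])
    break
```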