From b7afac351b61a1f90e9b0611e267731058c8cda0 Mon Sep 17 00:00:00 2001
From: Lezwon Castelino
Date: Fri, 31 Jul 2020 15:57:57 +0530
Subject: [PATCH] Add onnx export (#2596)

* export model to onnx
* prepare data before exporting
* support for dataloaders and tensors
* added tests
* use example_input_array
  add to changelog
* updated docstring
* added onnx inference tests
* temp commit
* removed schema valid test
* add onnxruntime to environment.yml
* moved onnxruntime to environment.yml pip
* add example in doc
* add lines between code block
* added PR to changelog
* is file check

Co-authored-by: Jirka Borovec

* remove
* Co-authored-by: Jirka Borovec
* infer example outputs
* added doctest for onnx
* fix windows tests
* moved eval within condition block
* self.forward to self
* added docs
* fixed docs error
* added to toctree
* Update CHANGELOG.md

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                         |   2 +
 docs/source/index.rst                |   1 +
 docs/source/production_inference.rst |  63 ++++++++++++++++
 environment.yml                      |   1 +
 pytorch_lightning/core/lightning.py  |  39 +++++++++
 requirements/extra.txt               |   2 +
 tests/base/model_template.py         |   4 +-
 tests/models/test_onnx_save.py       | 114 +++++++++++++++++++++++++++
 8 files changed, 223 insertions(+), 3 deletions(-)
 create mode 100644 docs/source/production_inference.rst
 create mode 100644 tests/models/test_onnx_save.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a4cd39873d..ed0880c345 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))
 - Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
+- Added support to export a model to ONNX format ([#2596](https://github.com/PyTorchLightning/pytorch-lightning/pull/2596))
+
 - Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))
 
 - Added support for PyTorch 1.6 ([#2745](https://github.com/PyTorchLightning/pytorch-lightning/pull/2745))
 
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4b1b7c697a..3637892848 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -99,6 +99,7 @@ PyTorch Lightning Documentation
    transfer_learning
    tpu
    test_set
+   production_inference
 
 .. toctree::
    :maxdepth: 1
diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst
new file mode 100644
index 0000000000..3159abe630
--- /dev/null
+++ b/docs/source/production_inference.rst
@@ -0,0 +1,63 @@
+Inference in Production
+=======================
+PyTorch Lightning eases the process of deploying models into production.
+
+
+Exporting to ONNX
+-----------------
+PyTorch Lightning provides a handy function to quickly export your model to ONNX format, which allows the model to be independent of PyTorch and run with ONNX Runtime.
+
+To export your model to ONNX format, call the ``to_onnx`` function on your LightningModule, passing the file path and an input sample.
+
+.. code-block:: python
+
+    filepath = 'model.onnx'
+    model = SimpleModel()
+    input_sample = torch.randn((1, 64))
+    model.to_onnx(filepath, input_sample, export_params=True)
+
+You can also skip passing the input sample if the ``example_input_array`` attribute is set in your LightningModule.
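+
+For example, ``example_input_array`` can be set in the model's constructor. The ``SimpleModel`` below is only a minimal sketch of this pattern, mirroring the doctest of ``LightningModule.to_onnx``:
+
+.. code-block:: python
+
+    import torch
+    from pytorch_lightning import LightningModule
+
+    class SimpleModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.l1 = torch.nn.Linear(in_features=64, out_features=4)
+            # to_onnx falls back to this sample for tracing when no
+            # explicit input_sample is passed
+            self.example_input_array = torch.randn((1, 64))
+
+        def forward(self, x):
+            return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+    SimpleModel().to_onnx('model.onnx', export_params=True)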
+
+Once you have the exported model, you can run it with ONNX Runtime in the following way:
+
+.. code-block:: python
+
+    import numpy as np
+    import onnxruntime
+
+    ort_session = onnxruntime.InferenceSession(filepath)
+    input_name = ort_session.get_inputs()[0].name
+    ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
+    ort_outs = ort_session.run(None, ort_inputs)
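+
+To sanity-check the export, you can compare the ONNX Runtime output with the original model's output. The following is a minimal sketch along the lines of this patch's ``test_if_inference_output_is_valid`` test, reusing ``model``, ``input_sample``, ``ort_session`` and ``input_name`` from the snippets above:
+
+.. code-block:: python
+
+    model.eval()
+    with torch.no_grad():
+        torch_out = model(input_sample)
+
+    ort_inputs = {input_name: input_sample.cpu().numpy()}
+    ort_outs = ort_session.run(None, ort_inputs)
+    assert np.allclose(torch_out.cpu().numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)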
diff --git a/environment.yml b/environment.yml
index 9c48f6d7e2..07afe80557 100644
--- a/environment.yml
+++ b/environment.yml
@@ -48,3 +48,4 @@ dependencies:
       - wandb>=0.8.21
       - neptune-client>=0.4.109
       - horovod>=0.19.1
+      - onnxruntime>=1.3.0
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index afb9fa0a92..5ff64156e1 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -2,6 +2,7 @@ import collections
 import inspect
 import os
 import re
+import tempfile
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -1723,6 +1724,44 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         else:
             self._hparams = hp
 
+    def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
+        """Saves the model in ONNX format.
+
+        Args:
+            file_path: The path of the file the model should be saved to.
+            input_sample: A sample of an input tensor for tracing.
+            **kwargs: Will be passed to the ``torch.onnx.export`` function.
+
+        Example:
+            >>> class SimpleModel(LightningModule):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
+            ...
+            ...     def forward(self, x):
+            ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+            >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
+            ...     model = SimpleModel()
+            ...     input_sample = torch.randn((1, 64))
+            ...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
+            ...     os.path.isfile(tmpfile.name)
+            True
+        """
+
+        if isinstance(input_sample, Tensor):
+            input_data = input_sample
+        elif self.example_input_array is not None:
+            input_data = self.example_input_array
+        else:
+            raise ValueError('input_sample and example_input_array tensors are both missing.')
+
+        if 'example_outputs' not in kwargs:
+            self.eval()
+            kwargs['example_outputs'] = self(input_data)
+
+        torch.onnx.export(self, input_data, file_path, **kwargs)
+
     @property
     def hparams(self) -> Union[AttributeDict, str]:
         if not hasattr(self, '_hparams'):
diff --git a/requirements/extra.txt b/requirements/extra.txt
index 191d24125a..31ea41c083 100644
--- a/requirements/extra.txt
+++ b/requirements/extra.txt
@@ -12,3 +12,5 @@ omegaconf>=2.0.0
 # scipy>=0.13.3
 scikit-learn>=0.20.0
 torchtext>=0.3.1, <0.7 # TODO: temporary fix for compatibility
+onnx>=1.7.0
+onnxruntime>=1.3.0
\ No newline at end of file
diff --git a/tests/base/model_template.py b/tests/base/model_template.py
index f529ce5735..19fcd42195 100644
--- a/tests/base/model_template.py
+++ b/tests/base/model_template.py
@@ -73,9 +73,7 @@ class EvalModelTemplate(
         self.test_step_end_called = False
         self.test_epoch_end_called = False
 
-        # if you specify an example input, the summary will show input/output for each layer
-        # TODO: to be fixed in #1773
-        # self.example_input_array = torch.rand(5, 28 * 28)
+        self.example_input_array = torch.rand(5, 28 * 28)
 
         # build model
         self.__build_model()
diff --git a/tests/models/test_onnx_save.py b/tests/models/test_onnx_save.py
new file mode 100644
index 0000000000..f824f33c93
--- /dev/null
+++ b/tests/models/test_onnx_save.py
@@ -0,0 +1,114 @@
+import os
+
+import onnxruntime
+import pytest
+import torch
+import numpy as np
+import tests.base.develop_pipelines as tpipes
+import tests.base.develop_utils as tutils
+from pytorch_lightning import Trainer
+from tests.base import EvalModelTemplate
+
+
+def test_model_saves_with_input_sample(tmpdir):
+    """Test that the ONNX model saves with an input sample and that its size is greater than 3 MB"""
+    model = EvalModelTemplate()
+    trainer = Trainer(max_epochs=1)
+    trainer.fit(model)
+
+    file_path = os.path.join(tmpdir, "model.onnx")
+    input_sample = torch.randn((1, 28 * 28))
+    model.to_onnx(file_path, input_sample)
+    assert os.path.isfile(file_path)
+    assert os.path.getsize(file_path) > 3e+06
+
+
+def test_model_saves_with_example_output(tmpdir):
+    """Test that the ONNX model saves when provided with an example output"""
+    model = EvalModelTemplate()
+    trainer = Trainer(max_epochs=1)
+    trainer.fit(model)
+
+    file_path = os.path.join(tmpdir, "model.onnx")
+    input_sample = torch.randn((1, 28 * 28))
+    model.eval()
+    example_outputs = model.forward(input_sample)
+    model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
+    assert os.path.exists(file_path)
+
+
+def test_model_saves_with_example_input_array(tmpdir):
+    """Test that the ONNX model saves with example_input_array and that its size is greater than 3 MB"""
+    model = EvalModelTemplate()
+    file_path = os.path.join(tmpdir, "model.onnx")
+    model.to_onnx(file_path)
+    assert os.path.exists(file_path)
+    assert os.path.getsize(file_path) > 3e+06
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_model_saves_on_multi_gpu(tmpdir):
+    """Test that the ONNX model saves on a distributed backend"""
+    tutils.set_random_master_port()
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        gpus=[0, 1],
+        distributed_backend='ddp_spawn',
+        progress_bar_refresh_rate=0
+    )
+
+    model = EvalModelTemplate()
+
+    tpipes.run_model_test(trainer_options, model)
+
+    file_path = os.path.join(tmpdir, "model.onnx")
+    model.to_onnx(file_path)
+    assert os.path.exists(file_path)
+
+
+def test_verbose_param(tmpdir, capsys):
+    """Test that output is present when the verbose parameter is set"""
+    model = EvalModelTemplate()
+    file_path = os.path.join(tmpdir, "model.onnx")
+    model.to_onnx(file_path, verbose=True)
+    captured = capsys.readouterr()
+    assert "graph(%" in captured.out
+
+
+def test_error_if_no_input(tmpdir):
+    """Test that an exception is thrown when there is no input tensor"""
+    model = EvalModelTemplate()
+    model.example_input_array = None
+    file_path = os.path.join(tmpdir, "model.onnx")
+    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
+        model.to_onnx(file_path)
+
+
+def test_if_inference_output_is_valid(tmpdir):
+    """Test that the output inferred from the ONNX model is the same as from PyTorch"""
+    model = EvalModelTemplate()
+    trainer = Trainer(max_epochs=5)
+    trainer.fit(model)
+
+    model.eval()
+    with torch.no_grad():
+        torch_out = model(model.example_input_array)
+
+    file_path = os.path.join(tmpdir, "model.onnx")
+    model.to_onnx(file_path, model.example_input_array, export_params=True)
+
+    ort_session = onnxruntime.InferenceSession(file_path)
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    # compute ONNX Runtime output prediction
+    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model.example_input_array)}
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    # compare ONNX Runtime and PyTorch results
+    assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)