Add onnx export (#2596)
* export model to onnx
* prepare data before exporting
* support for dataloaders and tensors
* added tests
* use example_input_array add to changelog
* updated docstring
* added onnx inference tests
* temp commit
* removed schema valid test
* add onnxruntime to environment.yml
* moved onnxruntime to environment.yml pip
* add example in doc
* add lines between code block
* added PR to changelog
* is file check

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* remove

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* infer example outputs
* added doctest for onnx
* fix windows tests
* moved eval within condition block
* self.forward to self
* added docs
* fixed docs error
* added to toctree
* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 06e8910f06
commit b7afac351b
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))

- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

- Added support to export a model to ONNX format ([#2596](https://github.com/PyTorchLightning/pytorch-lightning/pull/2596))

- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))

- Added support for PyTorch 1.6 ([#2745](https://github.com/PyTorchLightning/pytorch-lightning/pull/2745))

@@ -99,6 +99,7 @@ PyTorch Lightning Documentation

   transfer_learning
   tpu
   test_set
   production_inference

.. toctree::
   :maxdepth: 1

@@ -0,0 +1,28 @@
Inference in Production
=======================
PyTorch Lightning eases the process of deploying models into production.


Exporting to ONNX
-----------------
PyTorch Lightning provides a handy function to quickly export your model to the ONNX format, which allows the model to be independent of PyTorch and run with ONNX Runtime.

To export your model to the ONNX format, call the `to_onnx` method on your `LightningModule`, passing the file path and an input sample.

.. code-block:: python

    filepath = 'model.onnx'
    model = SimpleModel()
    input_sample = torch.randn((1, 64))
    model.to_onnx(filepath, input_sample, export_params=True)

You can also skip passing the input sample if the `example_input_array` attribute is specified in your LightningModule.
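
For instance, here is a minimal sketch of that fallback (the `SimpleModel` definition and the 64-feature input shape are illustrative assumptions carried over from the example above):

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            # to_onnx falls back to this sample when no input_sample is passed
            self.example_input_array = torch.randn((1, 64))

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))

    model = SimpleModel()
    model.to_onnx('model.onnx', export_params=True)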

Once you have the exported model, you can run it with ONNX Runtime in the following way:

.. code-block:: python

    import numpy as np
    import onnxruntime

    ort_session = onnxruntime.InferenceSession(filepath)
    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)
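
To sanity-check the export, you can compare the ONNX Runtime prediction with the PyTorch forward pass, much as the tests in this commit do. A minimal sketch, reusing `model`, `filepath`, and `ort_session` from the examples above:

.. code-block:: python

    model.eval()
    x = torch.randn((1, 64))
    with torch.no_grad():
        torch_out = model(x)

    ort_inputs = {ort_session.get_inputs()[0].name: x.numpy()}
    ort_outs = ort_session.run(None, ort_inputs)

    # both backends should agree up to small numerical differences
    assert np.allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
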
@@ -48,3 +48,4 @@ dependencies:
  - wandb>=0.8.21
  - neptune-client>=0.4.109
  - horovod>=0.19.1
  - onnxruntime>=1.3.0

@@ -2,6 +2,7 @@ import collections
import inspect
import os
import re
import tempfile
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

@@ -1723,6 +1724,44 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        else:
            self._hparams = hp

    def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
        """Saves the model in ONNX format

        Args:
            file_path: The path of the file the model should be saved to.
            input_sample: A sample of an input tensor for tracing.
            **kwargs: Will be passed to the `torch.onnx.export` function.

        Example:
            >>> class SimpleModel(LightningModule):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            ...
            ...     def forward(self, x):
            ...         return torch.relu(self.l1(x.view(x.size(0), -1)))

            >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
            ...     model = SimpleModel()
            ...     input_sample = torch.randn((1, 64))
            ...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
            ...     os.path.isfile(tmpfile.name)
            True
        """
        if isinstance(input_sample, Tensor):
            input_data = input_sample
        elif self.example_input_array is not None:
            input_data = self.example_input_array
        else:
            raise ValueError('input_sample and example_input_array tensors are both missing.')

        if 'example_outputs' not in kwargs:
            # infer example outputs from a forward pass in eval mode
            self.eval()
            kwargs['example_outputs'] = self(input_data)

        torch.onnx.export(self, input_data, file_path, **kwargs)

    @property
    def hparams(self) -> Union[AttributeDict, str]:
        if not hasattr(self, '_hparams'):

@@ -12,3 +12,5 @@ omegaconf>=2.0.0
# scipy>=0.13.3
scikit-learn>=0.20.0
torchtext>=0.3.1, <0.7  # TODO: temporary fix for compatibility
onnx>=1.7.0
onnxruntime>=1.3.0

@@ -73,9 +73,7 @@ class EvalModelTemplate(
        self.test_step_end_called = False
        self.test_epoch_end_called = False

        # if you specify an example input, the summary will show input/output for each layer
        # TODO: to be fixed in #1773
        # self.example_input_array = torch.rand(5, 28 * 28)
        self.example_input_array = torch.rand(5, 28 * 28)

        # build model
        self.__build_model()

@@ -0,0 +1,114 @@
import os

import numpy as np
import onnxruntime
import pytest
import torch

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate


def test_model_saves_with_input_sample(tmpdir):
    """Test that ONNX model saves with input sample and size is greater than 3 MB"""
    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    file_path = os.path.join(tmpdir, "model.onnx")
    input_sample = torch.randn((1, 28 * 28))
    model.to_onnx(file_path, input_sample)
    assert os.path.isfile(file_path)
    assert os.path.getsize(file_path) > 3e+06


def test_model_saves_with_example_output(tmpdir):
    """Test that ONNX model saves when provided with example output"""
    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    file_path = os.path.join(tmpdir, "model.onnx")
    input_sample = torch.randn((1, 28 * 28))
    model.eval()
    example_outputs = model.forward(input_sample)
    model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
    assert os.path.exists(file_path)


def test_model_saves_with_example_input_array(tmpdir):
    """Test that ONNX model saves with example_input_array and size is greater than 3 MB"""
    model = EvalModelTemplate()
    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path)
    assert os.path.exists(file_path)
    assert os.path.getsize(file_path) > 3e+06


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_model_saves_on_multi_gpu(tmpdir):
    """Test that ONNX model saves on a distributed backend"""
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        distributed_backend='ddp_spawn',
        progress_bar_refresh_rate=0,
    )

    model = EvalModelTemplate()

    tpipes.run_model_test(trainer_options, model)

    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path)
    assert os.path.exists(file_path)


def test_verbose_param(tmpdir, capsys):
    """Test that output is present when verbose parameter is set"""
    model = EvalModelTemplate()
    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path, verbose=True)
    captured = capsys.readouterr()
    assert "graph(%" in captured.out


def test_error_if_no_input(tmpdir):
    """Test that an exception is thrown when there is no input tensor"""
    model = EvalModelTemplate()
    model.example_input_array = None
    file_path = os.path.join(tmpdir, "model.onnx")
    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
        model.to_onnx(file_path)


def test_if_inference_output_is_valid(tmpdir):
    """Test that the output inferred from ONNX model is same as from PyTorch"""
    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=5)
    trainer.fit(model)

    model.eval()
    with torch.no_grad():
        torch_out = model(model.example_input_array)

    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path, model.example_input_array, export_params=True)

    ort_session = onnxruntime.InferenceSession(file_path)

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model.example_input_array)}
    ort_outs = ort_session.run(None, ort_inputs)

    # compare ONNX Runtime and PyTorch results
    assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)