Allow exporting to onnx when input is tuple (#8800)

Fixes #8799
This commit is contained in:
Pavel Grunt 2021-09-02 03:36:20 +02:00 committed by GitHub
parent 35876bb75f
commit e2ecb8f859
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 5 deletions

View File

@ -315,6 +315,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `accelerator=ddp` choice for CPU ([#8645](https://github.com/PyTorchLightning/pytorch-lightning/pull/8645))
- Fixed an issues with export to ONNX format when a model has multiple inputs ([#8800](https://github.com/PyTorchLightning/pytorch-lightning/pull/8800))
## [1.4.0] - 2021-07-27
### Added

View File

@ -1843,7 +1843,10 @@ class LightningModule(
if "example_outputs" not in kwargs:
self.eval()
kwargs["example_outputs"] = self(input_sample)
if isinstance(input_sample, Tuple):
kwargs["example_outputs"] = self(*input_sample)
else:
kwargs["example_outputs"] = self(input_sample)
torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)

View File

@ -23,6 +23,7 @@ import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
from tests.utilities.test_model_summary import UnorderedModel
def test_model_saves_with_input_sample(tmpdir):
@ -66,10 +67,17 @@ def test_model_saves_with_example_output(tmpdir):
assert os.path.exists(file_path) is True
def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
@pytest.mark.parametrize(
["modelclass", "input_sample"],
[
(BoringModel, torch.randn(1, 32)),
(UnorderedModel, (torch.rand(2, 3), torch.rand(2, 10))),
],
)
def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample):
"""Test that ONNX model saves with example_input_array and size is greater than 3 MB"""
model = modelclass()
model.example_input_array = input_sample
file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)