diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bda9ff766..d86f5ac2e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -315,6 +315,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `accelerator=ddp` choice for CPU ([#8645](https://github.com/PyTorchLightning/pytorch-lightning/pull/8645)) +- Fixed an issues with export to ONNX format when a model has multiple inputs ([#8800](https://github.com/PyTorchLightning/pytorch-lightning/pull/8800)) + + ## [1.4.0] - 2021-07-27 ### Added diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7e78a40bc5..5954af1c67 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1843,7 +1843,10 @@ class LightningModule( if "example_outputs" not in kwargs: self.eval() - kwargs["example_outputs"] = self(input_sample) + if isinstance(input_sample, Tuple): + kwargs["example_outputs"] = self(*input_sample) + else: + kwargs["example_outputs"] = self(input_sample) torch.onnx.export(self, input_sample, file_path, **kwargs) self.train(mode) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index cec01e828d..7cd1d2776f 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -23,6 +23,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer from tests.helpers import BoringModel from tests.helpers.runif import RunIf +from tests.utilities.test_model_summary import UnorderedModel def test_model_saves_with_input_sample(tmpdir): @@ -66,10 +67,17 @@ def test_model_saves_with_example_output(tmpdir): assert os.path.exists(file_path) is True -def test_model_saves_with_example_input_array(tmpdir): - """Test that ONNX model saves with_example_input_array and size is greater than 3 MB""" - model = BoringModel() - model.example_input_array = torch.randn(5, 32) +@pytest.mark.parametrize( + ["modelclass", "input_sample"], + [ + (BoringModel, torch.randn(1, 32)), + (UnorderedModel, (torch.rand(2, 3), torch.rand(2, 10))), + ], +) +def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample): + """Test that ONNX model saves with example_input_array and size is greater than 3 MB""" + model = modelclass() + model.example_input_array = input_sample file_path = os.path.join(tmpdir, "model.onnx") model.to_onnx(file_path)