parent
35876bb75f
commit
e2ecb8f859
|
@ -315,6 +315,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed `accelerator=ddp` choice for CPU ([#8645](https://github.com/PyTorchLightning/pytorch-lightning/pull/8645))
|
||||
|
||||
|
||||
- Fixed an issues with export to ONNX format when a model has multiple inputs ([#8800](https://github.com/PyTorchLightning/pytorch-lightning/pull/8800))
|
||||
|
||||
|
||||
## [1.4.0] - 2021-07-27
|
||||
|
||||
### Added
|
||||
|
|
|
@ -1843,6 +1843,9 @@ class LightningModule(
|
|||
|
||||
if "example_outputs" not in kwargs:
|
||||
self.eval()
|
||||
if isinstance(input_sample, Tuple):
|
||||
kwargs["example_outputs"] = self(*input_sample)
|
||||
else:
|
||||
kwargs["example_outputs"] = self(input_sample)
|
||||
|
||||
torch.onnx.export(self, input_sample, file_path, **kwargs)
|
||||
|
|
|
@ -23,6 +23,7 @@ import tests.helpers.utils as tutils
|
|||
from pytorch_lightning import Trainer
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
from tests.utilities.test_model_summary import UnorderedModel
|
||||
|
||||
|
||||
def test_model_saves_with_input_sample(tmpdir):
|
||||
|
@ -66,10 +67,17 @@ def test_model_saves_with_example_output(tmpdir):
|
|||
assert os.path.exists(file_path) is True
|
||||
|
||||
|
||||
def test_model_saves_with_example_input_array(tmpdir):
|
||||
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
|
||||
model = BoringModel()
|
||||
model.example_input_array = torch.randn(5, 32)
|
||||
@pytest.mark.parametrize(
|
||||
["modelclass", "input_sample"],
|
||||
[
|
||||
(BoringModel, torch.randn(1, 32)),
|
||||
(UnorderedModel, (torch.rand(2, 3), torch.rand(2, 10))),
|
||||
],
|
||||
)
|
||||
def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample):
|
||||
"""Test that ONNX model saves with example_input_array and size is greater than 3 MB"""
|
||||
model = modelclass()
|
||||
model.example_input_array = input_sample
|
||||
|
||||
file_path = os.path.join(tmpdir, "model.onnx")
|
||||
model.to_onnx(file_path)
|
||||
|
|
Loading…
Reference in New Issue