Allow any input in to_onnx and to_torchscript (#4378)

* branch merge

* sample

* update with valid input tensors

* pep

* pathlib

* Updated with BoringModel and added more input types

* try fix

* pep

* skip test with torch < 1.4

* fix test

* Apply suggestions from code review

* update tests

* Allow any input in to_onnx and to_torchscript

* Update tests/models/test_torchscript.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* no_grad

* try fix random failing test

* rm example_input_array

* rm example_input_array

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: edenlightning <66261195+edenlightning@users.noreply.github.com>
This commit is contained in:
Rohit Gupta 2020-12-12 15:47:03 +05:30 committed by GitHub
parent b5a2afd232
commit 3100b7839a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 139 additions and 77 deletions

View File

@ -14,7 +14,7 @@
"""Various hooks to be used in the Lightning code."""
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
import torch
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
@ -501,7 +501,7 @@ class DataHooks:
will have an argument ``dataloader_idx`` which matches the order here.
"""
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.
@ -549,6 +549,7 @@ class DataHooks:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""
device = device or self.device
return move_data_to_device(batch, device)

View File

@ -22,6 +22,7 @@ import re
import tempfile
from abc import ABC
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import torch
@ -1530,12 +1531,19 @@ class LightningModule(
else:
self._hparams = hp
def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
"""Saves the model in ONNX format
@torch.no_grad()
def to_onnx(
self,
file_path: Union[str, Path],
input_sample: Optional[Any] = None,
**kwargs,
):
"""
Saves the model in ONNX format
Args:
file_path: The path of the file the model should be saved to.
input_sample: A sample of an input tensor for tracing.
file_path: The path of the file the onnx model should be saved to.
input_sample: An input for tracing. Default: None (Use self.example_input_array)
**kwargs: Will be passed to torch.onnx.export function.
Example:
@ -1554,31 +1562,32 @@ class LightningModule(
... os.path.isfile(tmpfile.name)
True
"""
mode = self.training
if isinstance(input_sample, Tensor):
input_data = input_sample
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
if input_sample is not None:
if input_sample is None:
if self.example_input_array is None:
raise ValueError(
f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
"Could not export to ONNX since neither `input_sample` nor"
" `model.example_input_array` attribute is set."
)
raise ValueError(
"Could not export to ONNX since neither `input_sample` nor"
" `model.example_input_array` attribute is set."
)
input_data = input_data.to(self.device)
input_sample = self.example_input_array
input_sample = self.transfer_batch_to_device(input_sample)
if "example_outputs" not in kwargs:
self.eval()
with torch.no_grad():
kwargs["example_outputs"] = self(input_data)
kwargs["example_outputs"] = self(input_sample)
torch.onnx.export(self, input_data, file_path, **kwargs)
torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)
@torch.no_grad()
def to_torchscript(
self, file_path: Optional[str] = None, method: Optional[str] = 'script',
example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
self,
file_path: Optional[Union[str, Path]] = None,
method: Optional[str] = 'script',
example_inputs: Optional[Any] = None,
**kwargs,
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
"""
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
@ -1590,7 +1599,7 @@ class LightningModule(
Args:
file_path: Path where to save the torchscript. Default: None (no file saved).
method: Whether to use TorchScript's script or trace method. Default: 'script'
example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
example_inputs: An input to be used to do tracing when method is set to 'trace'.
Default: None (Use self.example_input_array)
**kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
:func:`torch.jit.trace` function.
@ -1624,21 +1633,27 @@ class LightningModule(
This LightningModule as a torchscript, regardless of whether file_path is
defined or not.
"""
mode = self.training
with torch.no_grad():
if method == 'script':
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == 'trace':
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
example_inputs = self.example_input_array
# automatically send example inputs to the right device and use trace
example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
f"{method}")
if method == 'script':
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == 'trace':
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
if self.example_input_array is None:
raise ValueError(
'Choosing method=`trace` requires either `example_inputs`'
' or `model.example_input_array` to be defined'
)
example_inputs = self.example_input_array
# automatically send example inputs to the right device and use trace
example_inputs = self.transfer_batch_to_device(example_inputs)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError("The 'method' parameter only supports 'script' or 'trace',"
f" but value given was: {method}")
self.train(mode)
if file_path is not None:

View File

@ -21,44 +21,44 @@ import torch
import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
from tests.base import BoringModel, EvalModelTemplate
def test_model_saves_with_input_sample(tmpdir):
"""Test that ONNX model saves with input sample and size is greater than 3 MB"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_model_saves_on_gpu(tmpdir):
"""Test that model saves on gpu"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(gpus=1, max_epochs=1)
trainer.fit(model)
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2
def test_model_saves_with_example_output(tmpdir):
"""Test that ONNX model saves when provided with example output"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.eval()
example_outputs = model.forward(input_sample)
model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
@ -67,11 +67,13 @@ def test_model_saves_with_example_output(tmpdir):
def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = EvalModelTemplate()
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)
assert os.path.exists(file_path) is True
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@ -100,7 +102,9 @@ def test_model_saves_on_multi_gpu(tmpdir):
def test_verbose_param(tmpdir, capsys):
"""Test that output is present when verbose parameter is set"""
model = EvalModelTemplate()
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path, verbose=True)
captured = capsys.readouterr()
@ -108,8 +112,8 @@ def test_verbose_param(tmpdir, capsys):
def test_error_if_no_input(tmpdir):
"""Test that an exception is thrown when there is no input tensor"""
model = EvalModelTemplate()
"""Test that an error is thrown when there is no input tensor"""
model = BoringModel()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onnx")
with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor'
@ -117,21 +121,12 @@ def test_error_if_no_input(tmpdir):
model.to_onnx(file_path)
def test_error_if_input_sample_is_not_tensor(tmpdir):
"""Test that an exception is thrown when there is no input tensor"""
model = EvalModelTemplate()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = np.random.randn(1, 28 * 28)
with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is '
f'`Tensor`'):
model.to_onnx(file_path, input_sample)
def test_if_inference_output_is_valid(tmpdir):
"""Test that the output inferred from ONNX model is same as from PyTorch"""
model = EvalModelTemplate()
trainer = Trainer(max_epochs=5)
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
trainer = Trainer(max_epochs=2)
trainer.fit(model)
model.eval()

View File

@ -16,43 +16,72 @@ from distutils.version import LooseVersion
import pytest
import torch
from tests.base import EvalModelTemplate
from tests.base import BoringModel
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.models import ParityModuleRNN, BasicGAN
@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
BoringModel,
ParityModuleRNN,
BasicGAN,
])
def test_torchscript_input_output(modelclass):
""" Test that scripted LightningModule forward works. """
model = modelclass()
if isinstance(model, BoringModel):
model.example_input_array = torch.randn(5, 32)
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
model.eval()
model_output = model(model.example_input_array)
with torch.no_grad():
model_output = model(model.example_input_array)
script_output = script(model.example_input_array)
assert torch.allclose(script_output, model_output)
@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
BoringModel,
ParityModuleRNN,
BasicGAN,
])
def test_torchscript_input_output_trace(modelclass):
""" Test that traced LightningModule forward works. """
def test_torchscript_example_input_output_trace(modelclass):
""" Test that traced LightningModule forward works with example_input_array """
model = modelclass()
if isinstance(model, BoringModel):
model.example_input_array = torch.randn(5, 32)
script = model.to_torchscript(method='trace')
assert isinstance(script, torch.jit.ScriptModule)
model.eval()
model_output = model(model.example_input_array)
with torch.no_grad():
model_output = model(model.example_input_array)
script_output = script(model.example_input_array)
assert torch.allclose(script_output, model_output)
def test_torchscript_input_output_trace():
""" Test that traced LightningModule forward works with example_inputs """
model = BoringModel()
example_inputs = torch.randn(1, 32)
script = model.to_torchscript(example_inputs=example_inputs, method='trace')
assert isinstance(script, torch.jit.ScriptModule)
model.eval()
with torch.no_grad():
model_output = model(example_inputs)
script_output = script(example_inputs)
assert torch.allclose(script_output, model_output)
@pytest.mark.parametrize("device", [
torch.device("cpu"),
torch.device("cuda", 0)
@ -60,7 +89,9 @@ def test_torchscript_input_output_trace(modelclass):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
def test_torchscript_device(device):
""" Test that scripted module is on the correct device. """
model = EvalModelTemplate().to(device)
model = BoringModel().to(device)
model.example_input_array = torch.randn(5, 32)
script = model.to_torchscript()
assert next(script.parameters()).device == device
script_output = script(model.example_input_array.to(device))
@ -69,7 +100,7 @@ def test_torchscript_device(device):
def test_torchscript_retain_training_state():
""" Test that torchscript export does not alter the training mode of original model. """
model = EvalModelTemplate()
model = BoringModel()
model.train(True)
script = model.to_torchscript()
assert model.training
@ -81,7 +112,7 @@ def test_torchscript_retain_training_state():
@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
BoringModel,
ParityModuleRNN,
BasicGAN,
])
@ -100,7 +131,7 @@ def test_torchscript_properties(modelclass):
@pytest.mark.parametrize("modelclass", [
EvalModelTemplate,
BoringModel,
ParityModuleRNN,
BasicGAN,
])
@ -109,9 +140,27 @@ def test_torchscript_properties(modelclass):
reason="torch.save/load has bug loading script modules on torch <= 1.4",
)
def test_torchscript_save_load(tmpdir, modelclass):
""" Test that scripted LightningModules is correctly saved and can be loaded. """
""" Test that scripted LightningModule is correctly saved and can be loaded. """
model = modelclass()
output_file = str(tmpdir / "model.pt")
script = model.to_torchscript(file_path=output_file)
loaded_script = torch.jit.load(output_file)
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
def test_torchcript_invalid_method(tmpdir):
"""Test that an error is thrown with invalid torchscript method"""
model = BoringModel()
model.train(True)
with pytest.raises(ValueError, match="only supports 'script' or 'trace'"):
model.to_torchscript(method='temp')
def test_torchscript_with_no_input(tmpdir):
"""Test that an error is thrown when there is no input tensor"""
model = BoringModel()
model.example_input_array = None
with pytest.raises(ValueError, match='requires either `example_inputs` or `model.example_input_array`'):
model.to_torchscript(method='trace')

View File

@ -958,6 +958,7 @@ def test_gradient_clipping(tmpdir):
"""
Test gradient clipping
"""
tutils.reset_seed()
model = EvalModelTemplate()
@ -995,6 +996,7 @@ def test_gradient_clipping_fp16(tmpdir):
"""
Test gradient clipping with fp16
"""
tutils.reset_seed()
model = EvalModelTemplate()