Allow any input in to_onnx and to_torchscript (#4378)
* branch merge
* sample
* update with valid input tensors
* pep
* pathlib
* Updated with BoringModel and added more input types
* try fix
* pep
* skip test with torch < 1.4
* fix test
* Apply suggestions from code review
* update tests
* Allow any input in to_onnx and to_torchscript
* Update tests/models/test_torchscript.py

  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* no_grad
* try fix random failing test
* rm example_input_array
* rm example_input_array

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: edenlightning <66261195+edenlightning@users.noreply.github.com>
This commit is contained in:
parent
b5a2afd232
commit
3100b7839a
|
@ -14,7 +14,7 @@
|
|||
|
||||
"""Various hooks to be used in the Lightning code."""
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
|
||||
|
@ -501,7 +501,7 @@ class DataHooks:
|
|||
will have an argument ``dataloader_idx`` which matches the order here.
|
||||
"""
|
||||
|
||||
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
|
||||
def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
|
||||
"""
|
||||
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
|
||||
wrapped in a custom data structure.
|
||||
|
@ -549,6 +549,7 @@ class DataHooks:
|
|||
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
|
||||
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
|
||||
"""
|
||||
device = device or self.device
|
||||
return move_data_to_device(batch, device)
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import re
|
|||
import tempfile
|
||||
from abc import ABC
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
@ -1530,12 +1531,19 @@ class LightningModule(
|
|||
else:
|
||||
self._hparams = hp
|
||||
|
||||
def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
|
||||
"""Saves the model in ONNX format
|
||||
@torch.no_grad()
|
||||
def to_onnx(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
input_sample: Optional[Any] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Saves the model in ONNX format
|
||||
|
||||
Args:
|
||||
file_path: The path of the file the model should be saved to.
|
||||
input_sample: A sample of an input tensor for tracing.
|
||||
file_path: The path of the file the onnx model should be saved to.
|
||||
input_sample: An input for tracing. Default: None (Use self.example_input_array)
|
||||
**kwargs: Will be passed to torch.onnx.export function.
|
||||
|
||||
Example:
|
||||
|
@ -1554,31 +1562,32 @@ class LightningModule(
|
|||
... os.path.isfile(tmpfile.name)
|
||||
True
|
||||
"""
|
||||
mode = self.training
|
||||
|
||||
if isinstance(input_sample, Tensor):
|
||||
input_data = input_sample
|
||||
elif self.example_input_array is not None:
|
||||
input_data = self.example_input_array
|
||||
else:
|
||||
if input_sample is not None:
|
||||
if input_sample is None:
|
||||
if self.example_input_array is None:
|
||||
raise ValueError(
|
||||
f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
|
||||
"Could not export to ONNX since neither `input_sample` nor"
|
||||
" `model.example_input_array` attribute is set."
|
||||
)
|
||||
raise ValueError(
|
||||
"Could not export to ONNX since neither `input_sample` nor"
|
||||
" `model.example_input_array` attribute is set."
|
||||
)
|
||||
input_data = input_data.to(self.device)
|
||||
input_sample = self.example_input_array
|
||||
|
||||
input_sample = self.transfer_batch_to_device(input_sample)
|
||||
|
||||
if "example_outputs" not in kwargs:
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
kwargs["example_outputs"] = self(input_data)
|
||||
kwargs["example_outputs"] = self(input_sample)
|
||||
|
||||
torch.onnx.export(self, input_data, file_path, **kwargs)
|
||||
torch.onnx.export(self, input_sample, file_path, **kwargs)
|
||||
self.train(mode)
|
||||
|
||||
@torch.no_grad()
|
||||
def to_torchscript(
|
||||
self, file_path: Optional[str] = None, method: Optional[str] = 'script',
|
||||
example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
|
||||
self,
|
||||
file_path: Optional[Union[str, Path]] = None,
|
||||
method: Optional[str] = 'script',
|
||||
example_inputs: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
|
||||
"""
|
||||
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
|
||||
|
@ -1590,7 +1599,7 @@ class LightningModule(
|
|||
Args:
|
||||
file_path: Path where to save the torchscript. Default: None (no file saved).
|
||||
method: Whether to use TorchScript's script or trace method. Default: 'script'
|
||||
example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
|
||||
example_inputs: An input to be used to do tracing when method is set to 'trace'.
|
||||
Default: None (Use self.example_input_array)
|
||||
**kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
|
||||
:func:`torch.jit.trace` function.
|
||||
|
@ -1624,21 +1633,27 @@ class LightningModule(
|
|||
This LightningModule as a torchscript, regardless of whether file_path is
|
||||
defined or not.
|
||||
"""
|
||||
|
||||
mode = self.training
|
||||
with torch.no_grad():
|
||||
if method == 'script':
|
||||
torchscript_module = torch.jit.script(self.eval(), **kwargs)
|
||||
elif method == 'trace':
|
||||
# if no example inputs are provided, try to see if model has example_input_array set
|
||||
if example_inputs is None:
|
||||
example_inputs = self.example_input_array
|
||||
# automatically send example inputs to the right device and use trace
|
||||
example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
|
||||
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
|
||||
f"{method}")
|
||||
|
||||
if method == 'script':
|
||||
torchscript_module = torch.jit.script(self.eval(), **kwargs)
|
||||
elif method == 'trace':
|
||||
# if no example inputs are provided, try to see if model has example_input_array set
|
||||
if example_inputs is None:
|
||||
if self.example_input_array is None:
|
||||
raise ValueError(
|
||||
'Choosing method=`trace` requires either `example_inputs`'
|
||||
' or `model.example_input_array` to be defined'
|
||||
)
|
||||
example_inputs = self.example_input_array
|
||||
|
||||
# automatically send example inputs to the right device and use trace
|
||||
example_inputs = self.transfer_batch_to_device(example_inputs)
|
||||
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
|
||||
else:
|
||||
raise ValueError("The 'method' parameter only supports 'script' or 'trace',"
|
||||
f" but value given was: {method}")
|
||||
|
||||
self.train(mode)
|
||||
|
||||
if file_path is not None:
|
||||
|
|
|
@ -21,44 +21,44 @@ import torch
|
|||
import tests.base.develop_pipelines as tpipes
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.base import BoringModel, EvalModelTemplate
|
||||
|
||||
|
||||
def test_model_saves_with_input_sample(tmpdir):
|
||||
"""Test that ONNX model saves with input sample and size is greater than 3 MB"""
|
||||
model = EvalModelTemplate()
|
||||
model = BoringModel()
|
||||
trainer = Trainer(max_epochs=1)
|
||||
trainer.fit(model)
|
||||
|
||||
file_path = os.path.join(tmpdir, "model.onnx")
|
||||
input_sample = torch.randn((1, 28 * 28))
|
||||
input_sample = torch.randn((1, 32))
|
||||
model.to_onnx(file_path, input_sample)
|
||||
assert os.path.isfile(file_path)
|
||||
assert os.path.getsize(file_path) > 3e+06
|
||||
assert os.path.getsize(file_path) > 4e2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
def test_model_saves_on_gpu(tmpdir):
|
||||
"""Test that model saves on gpu"""
|
||||
model = EvalModelTemplate()
|
||||
model = BoringModel()
|
||||
trainer = Trainer(gpus=1, max_epochs=1)
|
||||
trainer.fit(model)
|
||||
|
||||
file_path = os.path.join(tmpdir, "model.onnx")
|
||||
input_sample = torch.randn((1, 28 * 28))
|
||||
input_sample = torch.randn((1, 32))
|
||||
model.to_onnx(file_path, input_sample)
|
||||
assert os.path.isfile(file_path)
|
||||
assert os.path.getsize(file_path) > 3e+06
|
||||
assert os.path.getsize(file_path) > 4e2
|
||||
|
||||
|
||||
def test_model_saves_with_example_output(tmpdir):
|
||||
"""Test that ONNX model saves when provided with example output"""
|
||||
model = EvalModelTemplate()
|
||||
model = BoringModel()
|
||||
trainer = Trainer(max_epochs=1)
|
||||
trainer.fit(model)
|
||||
|
||||
file_path = os.path.join(tmpdir, "model.onnx")
|
||||
input_sample = torch.randn((1, 28 * 28))
|
||||
input_sample = torch.randn((1, 32))
|
||||
model.eval()
|
||||
example_outputs = model.forward(input_sample)
|
||||
model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
|
||||
|
@ -67,11 +67,13 @@ def test_model_saves_with_example_output(tmpdir):
|
|||
|
||||
def test_model_saves_with_example_input_array(tmpdir):
|
||||
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
|
||||
model = EvalModelTemplate()
|
||||
model = BoringModel()
|
||||
model.example_input_array = torch.randn(5, 32)
|
||||
|
||||
file_path = os.path.join(tmpdir, "model.onnx")
|
||||
model.to_onnx(file_path)
|
||||
assert os.path.exists(file_path) is True
|
||||
assert os.path.getsize(file_path) > 3e+06
|
||||
assert os.path.getsize(file_path) > 4e2
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
|
@ -100,7 +102,9 @@ def test_model_saves_on_multi_gpu(tmpdir):
|
|||
|
||||
def test_verbose_param(tmpdir, capsys):
|
||||
"""Test that output is present when verbose parameter is set"""
|
||||
model = EvalModelTemplate()
|
||||
model = BoringModel()
|
||||
model.example_input_array = torch.randn(5, 32)
|
||||
|
||||
file_path = os.path.join(tmpdir, "model.onnx")
|
||||
model.to_onnx(file_path, verbose=True)
|
||||
captured = capsys.readouterr()
|
||||
|
@ -108,8 +112,8 @@ def test_verbose_param(tmpdir, capsys):
|
|||
|
||||
|
||||
def test_error_if_no_input(tmpdir):
|
||||
"""Test that an exception is thrown when there is no input tensor"""
|
||||
model = EvalModelTemplate()
|
||||
"""Test that an error is thrown when there is no input tensor"""
|
||||
model = BoringModel()
|
||||
model.example_input_array = None
|
||||
file_path = os.path.join(tmpdir, "model.onnx")
|
||||
with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor'
|
||||
|
@ -117,21 +121,12 @@ def test_error_if_no_input(tmpdir):
|
|||
model.to_onnx(file_path)
|
||||
|
||||
|
||||
def test_error_if_input_sample_is_not_tensor(tmpdir):
|
||||
"""Test that an exception is thrown when there is no input tensor"""
|
||||
model = EvalModelTemplate()
|
||||
model.example_input_array = None
|
||||
file_path = os.path.join(tmpdir, "model.onnx")
|
||||
input_sample = np.random.randn(1, 28 * 28)
|
||||
with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is '
|
||||
f'`Tensor`'):
|
||||
model.to_onnx(file_path, input_sample)
|
||||
|
||||
|
||||
def test_if_inference_output_is_valid(tmpdir):
|
||||
"""Test that the output inferred from ONNX model is same as from PyTorch"""
|
||||
model = EvalModelTemplate()
|
||||
trainer = Trainer(max_epochs=5)
|
||||
model = BoringModel()
|
||||
model.example_input_array = torch.randn(5, 32)
|
||||
|
||||
trainer = Trainer(max_epochs=2)
|
||||
trainer.fit(model)
|
||||
|
||||
model.eval()
|
||||
|
|
|
@ -16,43 +16,72 @@ from distutils.version import LooseVersion
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.base import BoringModel
|
||||
from tests.base.datamodules import TrialMNISTDataModule
|
||||
from tests.base.models import ParityModuleRNN, BasicGAN
|
||||
|
||||
|
||||
@pytest.mark.parametrize("modelclass", [
|
||||
EvalModelTemplate,
|
||||
BoringModel,
|
||||
ParityModuleRNN,
|
||||
BasicGAN,
|
||||
])
|
||||
def test_torchscript_input_output(modelclass):
|
||||
""" Test that scripted LightningModule forward works. """
|
||||
model = modelclass()
|
||||
|
||||
if isinstance(model, BoringModel):
|
||||
model.example_input_array = torch.randn(5, 32)
|
||||
|
||||
script = model.to_torchscript()
|
||||
assert isinstance(script, torch.jit.ScriptModule)
|
||||
|
||||
model.eval()
|
||||
model_output = model(model.example_input_array)
|
||||
with torch.no_grad():
|
||||
model_output = model(model.example_input_array)
|
||||
|
||||
script_output = script(model.example_input_array)
|
||||
assert torch.allclose(script_output, model_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("modelclass", [
|
||||
EvalModelTemplate,
|
||||
BoringModel,
|
||||
ParityModuleRNN,
|
||||
BasicGAN,
|
||||
])
|
||||
def test_torchscript_input_output_trace(modelclass):
|
||||
""" Test that traced LightningModule forward works. """
|
||||
def test_torchscript_example_input_output_trace(modelclass):
|
||||
""" Test that traced LightningModule forward works with example_input_array """
|
||||
model = modelclass()
|
||||
|
||||
if isinstance(model, BoringModel):
|
||||
model.example_input_array = torch.randn(5, 32)
|
||||
|
||||
script = model.to_torchscript(method='trace')
|
||||
assert isinstance(script, torch.jit.ScriptModule)
|
||||
|
||||
model.eval()
|
||||
model_output = model(model.example_input_array)
|
||||
with torch.no_grad():
|
||||
model_output = model(model.example_input_array)
|
||||
|
||||
script_output = script(model.example_input_array)
|
||||
assert torch.allclose(script_output, model_output)
|
||||
|
||||
|
||||
def test_torchscript_input_output_trace():
|
||||
""" Test that traced LightningModule forward works with example_inputs """
|
||||
model = BoringModel()
|
||||
example_inputs = torch.randn(1, 32)
|
||||
script = model.to_torchscript(example_inputs=example_inputs, method='trace')
|
||||
assert isinstance(script, torch.jit.ScriptModule)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
model_output = model(example_inputs)
|
||||
|
||||
script_output = script(example_inputs)
|
||||
assert torch.allclose(script_output, model_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", [
|
||||
torch.device("cpu"),
|
||||
torch.device("cuda", 0)
|
||||
|
@ -60,7 +89,9 @@ def test_torchscript_input_output_trace(modelclass):
|
|||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
|
||||
def test_torchscript_device(device):
|
||||
""" Test that scripted module is on the correct device. """
|
||||
model = EvalModelTemplate().to(device)
|
||||
model = BoringModel().to(device)
|
||||
model.example_input_array = torch.randn(5, 32)
|
||||
|
||||
script = model.to_torchscript()
|
||||
assert next(script.parameters()).device == device
|
||||
script_output = script(model.example_input_array.to(device))
|
||||
|
@ -69,7 +100,7 @@ def test_torchscript_device(device):
|
|||
|
||||
def test_torchscript_retain_training_state():
|
||||
""" Test that torchscript export does not alter the training mode of original model. """
|
||||
model = EvalModelTemplate()
|
||||
model = BoringModel()
|
||||
model.train(True)
|
||||
script = model.to_torchscript()
|
||||
assert model.training
|
||||
|
@ -81,7 +112,7 @@ def test_torchscript_retain_training_state():
|
|||
|
||||
|
||||
@pytest.mark.parametrize("modelclass", [
|
||||
EvalModelTemplate,
|
||||
BoringModel,
|
||||
ParityModuleRNN,
|
||||
BasicGAN,
|
||||
])
|
||||
|
@ -100,7 +131,7 @@ def test_torchscript_properties(modelclass):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("modelclass", [
|
||||
EvalModelTemplate,
|
||||
BoringModel,
|
||||
ParityModuleRNN,
|
||||
BasicGAN,
|
||||
])
|
||||
|
@ -109,9 +140,27 @@ def test_torchscript_properties(modelclass):
|
|||
reason="torch.save/load has bug loading script modules on torch <= 1.4",
|
||||
)
|
||||
def test_torchscript_save_load(tmpdir, modelclass):
|
||||
""" Test that scripted LightningModules is correctly saved and can be loaded. """
|
||||
""" Test that scripted LightningModule is correctly saved and can be loaded. """
|
||||
model = modelclass()
|
||||
output_file = str(tmpdir / "model.pt")
|
||||
script = model.to_torchscript(file_path=output_file)
|
||||
loaded_script = torch.jit.load(output_file)
|
||||
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
|
||||
|
||||
|
||||
def test_torchcript_invalid_method(tmpdir):
|
||||
"""Test that an error is thrown with invalid torchscript method"""
|
||||
model = BoringModel()
|
||||
model.train(True)
|
||||
|
||||
with pytest.raises(ValueError, match="only supports 'script' or 'trace'"):
|
||||
model.to_torchscript(method='temp')
|
||||
|
||||
|
||||
def test_torchscript_with_no_input(tmpdir):
|
||||
"""Test that an error is thrown when there is no input tensor"""
|
||||
model = BoringModel()
|
||||
model.example_input_array = None
|
||||
|
||||
with pytest.raises(ValueError, match='requires either `example_inputs` or `model.example_input_array`'):
|
||||
model.to_torchscript(method='trace')
|
||||
|
|
|
@ -958,6 +958,7 @@ def test_gradient_clipping(tmpdir):
|
|||
"""
|
||||
Test gradient clipping
|
||||
"""
|
||||
tutils.reset_seed()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
|
||||
|
@ -995,6 +996,7 @@ def test_gradient_clipping_fp16(tmpdir):
|
|||
"""
|
||||
Test gradient clipping with fp16
|
||||
"""
|
||||
tutils.reset_seed()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
|
||||
|
|
Loading…
Reference in New Issue