move example inputs to correct device when tracing module (#4360)

* use move_data_to_device instead of to; docstring also allow tuple of Tensor; not supported log error when example_inputs is a dict; commented docstring trace example

* Use isinstance to check if example_inputs is a Mapping, instead of type

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* import Mapping for isinstance check

* multi-line docstring code to test TorchScript trace()

* Fix PEP8 f-string is missing placeholders

* minor code style improvements

* Use (possibly user overwritten) transfer_batch_to_device instead of move_data_to_device

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* fixed weird comment about trace() log error

* Remove unused import

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>

* Remove logger warning about dict not example_inputs not supported by trace

Co-authored-by: stef-ubuntu <stef@webempath.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
This commit is contained in:
Stef | ステフ 2020-10-29 14:46:57 +09:00 committed by GitHub
parent ff41d80706
commit 9cfd29946a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 4 deletions

View File

@ -20,7 +20,7 @@ import re
import tempfile
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping
import torch
from pytorch_lightning import _logger as log
@ -1539,7 +1539,7 @@ class LightningModule(
def to_torchscript(
self, file_path: Optional[str] = None, method: Optional[str] = 'script',
example_inputs: Optional[torch.Tensor] = None, **kwargs
example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
"""
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
@ -1576,6 +1576,9 @@ class LightningModule(
>>> model = SimpleModel()
>>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP
>>> os.path.isfile("model.pt") # doctest: +SKIP
>>> torch.jit.save(model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP
... example_inputs=torch.randn(1, 64))) # doctest: +SKIP
>>> os.path.isfile("model_trace.pt") # doctest: +SKIP
True
Return:
@ -1592,8 +1595,8 @@ class LightningModule(
if example_inputs is None:
example_inputs = self.example_input_array
# automatically send example inputs to the right device and use trace
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs.to(self.device),
**kwargs)
example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
f"{method}")