move example inputs to correct device when tracing module (#4360)
* use move_data_to_device instead of to; docstring also allow tuple of Tensor; not supported log error when example_inputs is a dict; commented docstring trace example * Use isinstance to check if example_inputs is a Mapping, instead of type Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * import Mapping for isinstance check * multi-line docstring code to test TorchScript trace() * Fix PEP8 f-string is missing placeholders * minor code style improvements * Use (possibly user overwritten) transfer_batch_to_device instead of move_data_to_device Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * fixed weird comment about trace() log error * Remove unused import Co-authored-by: Jeff Yang <ydcjeff@outlook.com> * Remove logger warning about dict not example_inputs not supported by trace Co-authored-by: stef-ubuntu <stef@webempath.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
This commit is contained in:
parent
ff41d80706
commit
9cfd29946a
|
@ -20,7 +20,7 @@ import re
|
|||
import tempfile
|
||||
from abc import ABC
|
||||
from argparse import Namespace
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import _logger as log
|
||||
|
@ -1539,7 +1539,7 @@ class LightningModule(
|
|||
|
||||
def to_torchscript(
|
||||
self, file_path: Optional[str] = None, method: Optional[str] = 'script',
|
||||
example_inputs: Optional[torch.Tensor] = None, **kwargs
|
||||
example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
|
||||
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
|
||||
"""
|
||||
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
|
||||
|
@ -1576,6 +1576,9 @@ class LightningModule(
|
|||
>>> model = SimpleModel()
|
||||
>>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP
|
||||
>>> os.path.isfile("model.pt") # doctest: +SKIP
|
||||
>>> torch.jit.save(model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP
|
||||
... example_inputs=torch.randn(1, 64))) # doctest: +SKIP
|
||||
>>> os.path.isfile("model_trace.pt") # doctest: +SKIP
|
||||
True
|
||||
|
||||
Return:
|
||||
|
@ -1592,8 +1595,8 @@ class LightningModule(
|
|||
if example_inputs is None:
|
||||
example_inputs = self.example_input_array
|
||||
# automatically send example inputs to the right device and use trace
|
||||
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs.to(self.device),
|
||||
**kwargs)
|
||||
example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
|
||||
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
|
||||
f"{method}")
|
||||
|
|
Loading…
Reference in New Issue