Make internal torchscript check a class attribute (#14904)
This commit is contained in:
parent
5f0c4aad12
commit
4eb7766f3c
|
@ -276,6 +276,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
|
||||
|
||||
|
||||
- Fixed torchscript error with containers of LightningModules ([#14904](https://github.com/Lightning-AI/lightning/pull/14904))
|
||||
|
||||
|
||||
|
||||
## [1.7.7] - 2022-09-22
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -88,11 +88,11 @@ class LightningModule(
|
|||
"automatic_optimization",
|
||||
"truncated_bptt_steps",
|
||||
"trainer",
|
||||
"_running_torchscript",
|
||||
]
|
||||
+ _DeviceDtypeModuleMixin.__jit_unused_properties__
|
||||
+ HyperparametersMixin.__jit_unused_properties__
|
||||
)
|
||||
_jit_is_scripting = False
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -115,8 +115,6 @@ class LightningModule(
|
|||
self._param_requires_grad_state: Dict[str, bool] = {}
|
||||
self._metric_attributes: Optional[Dict[int, str]] = None
|
||||
self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
|
||||
self._running_torchscript_internal = False # workaround for https://github.com/pytorch/pytorch/issues/67146
|
||||
|
||||
self._register_sharded_tensor_state_dict_hooks_if_available()
|
||||
|
||||
@overload
|
||||
|
@ -176,7 +174,7 @@ class LightningModule(
|
|||
|
||||
@property
|
||||
def trainer(self) -> "pl.Trainer":
|
||||
if not self._running_torchscript and self._trainer is None:
|
||||
if not self._jit_is_scripting and self._trainer is None:
|
||||
raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
|
||||
return self._trainer # type: ignore[return-value]
|
||||
|
||||
|
@ -271,17 +269,6 @@ class LightningModule(
|
|||
"""Reference to the list of loggers in the Trainer."""
|
||||
return self.trainer.loggers if self._trainer else []
|
||||
|
||||
@property
|
||||
def _running_torchscript(self) -> bool:
|
||||
return self._running_torchscript_internal
|
||||
|
||||
@_running_torchscript.setter
|
||||
def _running_torchscript(self, value: bool) -> None:
|
||||
for v in self.children():
|
||||
if isinstance(v, LightningModule):
|
||||
v._running_torchscript = value
|
||||
self._running_torchscript_internal = value
|
||||
|
||||
def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
|
||||
if self._trainer:
|
||||
datahook_selector = self._trainer._data_connector._datahook_selector
|
||||
|
@ -1889,10 +1876,9 @@ class LightningModule(
|
|||
"""
|
||||
mode = self.training
|
||||
|
||||
self._running_torchscript = True
|
||||
|
||||
if method == "script":
|
||||
torchscript_module = torch.jit.script(self.eval(), **kwargs)
|
||||
with _jit_is_scripting():
|
||||
torchscript_module = torch.jit.script(self.eval(), **kwargs)
|
||||
elif method == "trace":
|
||||
# if no example inputs are provided, try to see if model has example_input_array set
|
||||
if example_inputs is None:
|
||||
|
@ -1906,7 +1892,8 @@ class LightningModule(
|
|||
# automatically send example inputs to the right device and use trace
|
||||
example_inputs = self._on_before_batch_transfer(example_inputs)
|
||||
example_inputs = self._apply_batch_transfer_handler(example_inputs)
|
||||
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
|
||||
with _jit_is_scripting():
|
||||
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")
|
||||
|
||||
|
@ -1917,8 +1904,6 @@ class LightningModule(
|
|||
with fs.open(file_path, "wb") as f:
|
||||
torch.jit.save(torchscript_module, f)
|
||||
|
||||
self._running_torchscript = False
|
||||
|
||||
return torchscript_module
|
||||
|
||||
@contextmanager
|
||||
|
@ -1960,3 +1945,13 @@ class LightningModule(
|
|||
self.__class__._register_load_state_dict_pre_hook(
|
||||
weakref.proxy(self), pre_load_state_dict_hook, True # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _jit_is_scripting() -> Generator:
|
||||
"""Workaround for https://github.com/pytorch/pytorch/issues/67146."""
|
||||
LightningModule._jit_is_scripting = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
LightningModule._jit_is_scripting = False
|
||||
|
|
|
@ -242,9 +242,8 @@ class TensorBoardLogger(Logger):
|
|||
if input_array is not None:
|
||||
input_array = model._on_before_batch_transfer(input_array)
|
||||
input_array = model._apply_batch_transfer_handler(input_array)
|
||||
model._running_torchscript = True
|
||||
self.experiment.add_graph(model, input_array)
|
||||
model._running_torchscript = False
|
||||
with pl.core.module._jit_is_scripting():
|
||||
self.experiment.add_graph(model, input_array)
|
||||
else:
|
||||
rank_zero_warn(
|
||||
"Could not log computational graph since the"
|
||||
|
|
|
@ -275,12 +275,9 @@ class IPUStrategy(ParallelStrategy):
|
|||
|
||||
def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
||||
args = self._prepare_input(args)
|
||||
assert self.lightning_module is not None
|
||||
poptorch_model = self.poptorch_models[stage]
|
||||
self.lightning_module._running_torchscript = True
|
||||
out = poptorch_model(*args, **kwargs)
|
||||
self.lightning_module._running_torchscript = False
|
||||
return out
|
||||
with pl.core.module._jit_is_scripting():
|
||||
return poptorch_model(*args, **kwargs)
|
||||
|
||||
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
||||
with self.precision_plugin.train_step_context():
|
||||
|
|
|
@ -155,7 +155,7 @@ def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass):
|
|||
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
|
||||
|
||||
|
||||
def test_torchcript_invalid_method(tmpdir):
|
||||
def test_torchcript_invalid_method():
|
||||
"""Test that an error is thrown with invalid torchscript method."""
|
||||
model = BoringModel()
|
||||
model.train(True)
|
||||
|
@ -164,7 +164,7 @@ def test_torchcript_invalid_method(tmpdir):
|
|||
model.to_torchscript(method="temp")
|
||||
|
||||
|
||||
def test_torchscript_with_no_input(tmpdir):
|
||||
def test_torchscript_with_no_input():
|
||||
"""Test that an error is thrown when there is no input tensor."""
|
||||
model = BoringModel()
|
||||
model.example_input_array = None
|
||||
|
@ -185,7 +185,7 @@ def test_torchscript_script_recursively():
|
|||
class Child(LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = GrandChild()
|
||||
self.model = torch.nn.Sequential(GrandChild(), GrandChild())
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.model(inputs)
|
||||
|
@ -199,5 +199,7 @@ def test_torchscript_script_recursively():
|
|||
return self.model(inputs)
|
||||
|
||||
lm = Parent()
|
||||
assert not lm._jit_is_scripting
|
||||
script = lm.to_torchscript(method="script")
|
||||
assert not lm._jit_is_scripting
|
||||
assert isinstance(script, torch.jit.RecursiveScriptModule)
|
||||
|
|
Loading…
Reference in New Issue