Make internal torchscript check a class attribute (#14904)

Carlos Mocholí 2022-09-29 15:40:25 +02:00 committed by GitHub
parent 5f0c4aad12
commit 4eb7766f3c
5 changed files with 29 additions and 32 deletions

CHANGELOG.md

@@ -276,6 +276,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
- Fixed torchscript error with containers of LightningModules ([#14904](https://github.com/Lightning-AI/lightning/pull/14904))
## [1.7.7] - 2022-09-22
### Fixed

src/pytorch_lightning/core/module.py

@@ -88,11 +88,11 @@ class LightningModule(
"automatic_optimization",
"truncated_bptt_steps",
"trainer",
"_running_torchscript",
]
+ _DeviceDtypeModuleMixin.__jit_unused_properties__
+ HyperparametersMixin.__jit_unused_properties__
)
_jit_is_scripting = False
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
@@ -115,8 +115,6 @@ class LightningModule(
self._param_requires_grad_state: Dict[str, bool] = {}
self._metric_attributes: Optional[Dict[int, str]] = None
self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
self._running_torchscript_internal = False # workaround for https://github.com/pytorch/pytorch/issues/67146
self._register_sharded_tensor_state_dict_hooks_if_available()
@overload
@@ -176,7 +174,7 @@ class LightningModule(
@property
def trainer(self) -> "pl.Trainer":
if not self._running_torchscript and self._trainer is None:
if not self._jit_is_scripting and self._trainer is None:
raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
return self._trainer # type: ignore[return-value]
@@ -271,17 +269,6 @@ class LightningModule(
"""Reference to the list of loggers in the Trainer."""
return self.trainer.loggers if self._trainer else []
@property
def _running_torchscript(self) -> bool:
return self._running_torchscript_internal
@_running_torchscript.setter
def _running_torchscript(self, value: bool) -> None:
for v in self.children():
if isinstance(v, LightningModule):
v._running_torchscript = value
self._running_torchscript_internal = value
def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
if self._trainer:
datahook_selector = self._trainer._data_connector._datahook_selector
@@ -1889,10 +1876,9 @@ class LightningModule(
"""
mode = self.training
self._running_torchscript = True
if method == "script":
torchscript_module = torch.jit.script(self.eval(), **kwargs)
with _jit_is_scripting():
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == "trace":
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
@@ -1906,7 +1892,8 @@ class LightningModule(
# automatically send example inputs to the right device and use trace
example_inputs = self._on_before_batch_transfer(example_inputs)
example_inputs = self._apply_batch_transfer_handler(example_inputs)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
with _jit_is_scripting():
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")
@@ -1917,8 +1904,6 @@ class LightningModule(
with fs.open(file_path, "wb") as f:
torch.jit.save(torchscript_module, f)
self._running_torchscript = False
return torchscript_module
@contextmanager
@@ -1960,3 +1945,13 @@
self.__class__._register_load_state_dict_pre_hook(
weakref.proxy(self), pre_load_state_dict_hook, True # type: ignore[arg-type]
)
@contextmanager
def _jit_is_scripting() -> Generator:
"""Workaround for https://github.com/pytorch/pytorch/issues/67146."""
LightningModule._jit_is_scripting = True
try:
yield
finally:
LightningModule._jit_is_scripting = False
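
Aside (not part of the diff): a minimal self-contained sketch of the pattern the hunks above adopt, with illustrative names (`Base`, `_scripting`) rather than Lightning's. Because attribute lookup on an instance falls back to the class, flipping a class attribute is visible to every instance at once, including nested submodules, so the old recursive setter over `self.children()` becomes unnecessary, and `try/finally` guarantees the flag is cleared even if scripting raises. In the actual change, the class attribute and the module-level context manager share the name `_jit_is_scripting`; they live in different namespaces.

from contextlib import contextmanager
from typing import Generator


class Base:
    # Class attribute: one flag shared by every instance and subclass.
    _scripting = False


@contextmanager
def _scripting_ctx() -> Generator:
    Base._scripting = True  # visible to all instances via class-attribute lookup
    try:
        yield
    finally:
        Base._scripting = False  # runs on normal exit *and* on exception


parent, child = Base(), Base()
with _scripting_ctx():
    assert parent._scripting and child._scripting  # no recursion over children needed
assert not parent._scripting and not child._scripting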

src/pytorch_lightning/loggers/tensorboard.py

@@ -242,9 +242,8 @@ class TensorBoardLogger(Logger):
if input_array is not None:
input_array = model._on_before_batch_transfer(input_array)
input_array = model._apply_batch_transfer_handler(input_array)
model._running_torchscript = True
self.experiment.add_graph(model, input_array)
model._running_torchscript = False
with pl.core.module._jit_is_scripting():
self.experiment.add_graph(model, input_array)
else:
rank_zero_warn(
"Could not log computational graph since the"

src/pytorch_lightning/strategies/ipu.py

@@ -275,12 +275,9 @@ class IPUStrategy(ParallelStrategy):
def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
args = self._prepare_input(args)
assert self.lightning_module is not None
poptorch_model = self.poptorch_models[stage]
self.lightning_module._running_torchscript = True
out = poptorch_model(*args, **kwargs)
self.lightning_module._running_torchscript = False
return out
with pl.core.module._jit_is_scripting():
return poptorch_model(*args, **kwargs)
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
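
A small plain-Python aside on why the `with` block can absorb the old out/reset/return sequence above: a context manager's `finally` cleanup runs on normal return as well as on exception, so returning from inside the block cannot leak the flag the way a manual set/reset pair could if the wrapped call raised.

from contextlib import contextmanager


@contextmanager
def flag():
    print("set")
    try:
        yield
    finally:
        print("reset")  # executes on return *and* on exception


def step(fail: bool = False) -> int:
    with flag():
        if fail:
            raise RuntimeError("boom")
        return 42  # "reset" is still printed before the caller gets 42


step()  # prints: set, reset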

tests/tests_pytorch/models/test_torchscript.py

@@ -155,7 +155,7 @@ def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass):
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
def test_torchcript_invalid_method(tmpdir):
def test_torchcript_invalid_method():
"""Test that an error is thrown with invalid torchscript method."""
model = BoringModel()
model.train(True)
@@ -164,7 +164,7 @@ def test_torchcript_invalid_method(tmpdir):
model.to_torchscript(method="temp")
def test_torchscript_with_no_input(tmpdir):
def test_torchscript_with_no_input():
"""Test that an error is thrown when there is no input tensor."""
model = BoringModel()
model.example_input_array = None
@@ -185,7 +185,7 @@ def test_torchscript_script_recursively():
class Child(LightningModule):
def __init__(self):
super().__init__()
self.model = GrandChild()
self.model = torch.nn.Sequential(GrandChild(), GrandChild())
def forward(self, inputs):
return self.model(inputs)
@@ -199,5 +199,7 @@
return self.model(inputs)
lm = Parent()
assert not lm._jit_is_scripting
script = lm.to_torchscript(method="script")
assert not lm._jit_is_scripting
assert isinstance(script, torch.jit.RecursiveScriptModule)
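
Tying it together, a sketch of the behavior the new assertions pin down (BoringModel comes from Lightning's demo helpers; this mirrors the test rather than adding to it):

import torch
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
assert not model._jit_is_scripting  # flag is off outside to_torchscript()
script = model.to_torchscript(method="script")
assert not model._jit_is_scripting  # the context manager reset it on exit
assert isinstance(script, torch.jit.RecursiveScriptModule)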