diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index c6868311cd..8ccf3b30b4 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -276,6 +276,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
 
+- Fixed torchscript error with containers of LightningModules ([#14904](https://github.com/Lightning-AI/lightning/pull/14904))
+
+
+
 ## [1.7.7] - 2022-09-22
 
 ### Fixed
diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index 428ed58ea7..51af945c78 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -88,11 +88,11 @@ class LightningModule(
             "automatic_optimization",
             "truncated_bptt_steps",
             "trainer",
-            "_running_torchscript",
         ]
         + _DeviceDtypeModuleMixin.__jit_unused_properties__
         + HyperparametersMixin.__jit_unused_properties__
     )
+    _jit_is_scripting = False
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
@@ -115,8 +115,6 @@ class LightningModule(
         self._param_requires_grad_state: Dict[str, bool] = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
-        self._running_torchscript_internal = False  # workaround for https://github.com/pytorch/pytorch/issues/67146
-
         self._register_sharded_tensor_state_dict_hooks_if_available()
 
     @overload
@@ -176,7 +174,7 @@ class LightningModule(
 
     @property
     def trainer(self) -> "pl.Trainer":
-        if not self._running_torchscript and self._trainer is None:
+        if not self._jit_is_scripting and self._trainer is None:
             raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
         return self._trainer  # type: ignore[return-value]
 
@@ -271,17 +269,6 @@ class LightningModule(
         """Reference to the list of loggers in the Trainer."""
         return self.trainer.loggers if self._trainer else []
 
-    @property
-    def _running_torchscript(self) -> bool:
-        return self._running_torchscript_internal
-
-    @_running_torchscript.setter
-    def _running_torchscript(self, value: bool) -> None:
-        for v in self.children():
-            if isinstance(v, LightningModule):
-                v._running_torchscript = value
-        self._running_torchscript_internal = value
-
     def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
@@ -1889,10 +1876,9 @@ class LightningModule(
         """
         mode = self.training
 
-        self._running_torchscript = True
-
         if method == "script":
-            torchscript_module = torch.jit.script(self.eval(), **kwargs)
+            with _jit_is_scripting():
+                torchscript_module = torch.jit.script(self.eval(), **kwargs)
         elif method == "trace":
             # if no example inputs are provided, try to see if model has example_input_array set
             if example_inputs is None:
@@ -1906,7 +1892,8 @@ class LightningModule(
             # automatically send example inputs to the right device and use trace
             example_inputs = self._on_before_batch_transfer(example_inputs)
             example_inputs = self._apply_batch_transfer_handler(example_inputs)
-            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
+            with _jit_is_scripting():
+                torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
         else:
             raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")
@@ -1917,8 +1904,6 @@ class LightningModule(
             with fs.open(file_path, "wb") as f:
                 torch.jit.save(torchscript_module, f)
 
-        self._running_torchscript = False
-
         return torchscript_module
 
     @contextmanager
@@ -1960,3 +1945,13 @@ class LightningModule(
         self.__class__._register_load_state_dict_pre_hook(
             weakref.proxy(self), pre_load_state_dict_hook, True  # type: ignore[arg-type]
         )
+
+
+@contextmanager
+def _jit_is_scripting() -> Generator:
+    """Workaround for https://github.com/pytorch/pytorch/issues/67146."""
+    LightningModule._jit_is_scripting = True
+    try:
+        yield
+    finally:
+        LightningModule._jit_is_scripting = False
diff --git a/src/pytorch_lightning/loggers/tensorboard.py b/src/pytorch_lightning/loggers/tensorboard.py
index c5cfce12eb..c68d85d056 100644
--- a/src/pytorch_lightning/loggers/tensorboard.py
+++ b/src/pytorch_lightning/loggers/tensorboard.py
@@ -242,9 +242,8 @@ class TensorBoardLogger(Logger):
         if input_array is not None:
             input_array = model._on_before_batch_transfer(input_array)
             input_array = model._apply_batch_transfer_handler(input_array)
-            model._running_torchscript = True
-            self.experiment.add_graph(model, input_array)
-            model._running_torchscript = False
+            with pl.core.module._jit_is_scripting():
+                self.experiment.add_graph(model, input_array)
         else:
             rank_zero_warn(
                 "Could not log computational graph since the"
diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py
index 7f83e7dff5..a5daa13dc9 100644
--- a/src/pytorch_lightning/strategies/ipu.py
+++ b/src/pytorch_lightning/strategies/ipu.py
@@ -275,12 +275,9 @@ class IPUStrategy(ParallelStrategy):
 
     def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         args = self._prepare_input(args)
-        assert self.lightning_module is not None
         poptorch_model = self.poptorch_models[stage]
-        self.lightning_module._running_torchscript = True
-        out = poptorch_model(*args, **kwargs)
-        self.lightning_module._running_torchscript = False
-        return out
+        with pl.core.module._jit_is_scripting():
+            return poptorch_model(*args, **kwargs)
 
     def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.train_step_context():
diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py
index 857307a0f5..00bd75a595 100644
--- a/tests/tests_pytorch/models/test_torchscript.py
+++ b/tests/tests_pytorch/models/test_torchscript.py
@@ -155,7 +155,7 @@ def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass):
     assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
 
 
-def test_torchcript_invalid_method(tmpdir):
+def test_torchcript_invalid_method():
     """Test that an error is thrown with invalid torchscript method."""
     model = BoringModel()
     model.train(True)
@@ -164,7 +164,7 @@ def test_torchcript_invalid_method(tmpdir):
     model.to_torchscript(method="temp")
 
 
-def test_torchscript_with_no_input(tmpdir):
+def test_torchscript_with_no_input():
     """Test that an error is thrown when there is no input tensor."""
     model = BoringModel()
     model.example_input_array = None
@@ -185,7 +185,7 @@ def test_torchscript_script_recursively():
     class Child(LightningModule):
         def __init__(self):
             super().__init__()
-            self.model = GrandChild()
+            self.model = torch.nn.Sequential(GrandChild(), GrandChild())
 
         def forward(self, inputs):
             return self.model(inputs)
@@ -199,5 +199,7 @@ def test_torchscript_script_recursively():
             return self.model(inputs)
 
     lm = Parent()
+    assert not lm._jit_is_scripting
     script = lm.to_torchscript(method="script")
+    assert not lm._jit_is_scripting
    assert isinstance(script, torch.jit.RecursiveScriptModule)
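
For context, a minimal usage sketch of the scenario this patch addresses: a LightningModule whose LightningModule children sit inside a plain torch.nn container (e.g. torch.nn.Sequential), exported with to_torchscript(method="script"). This is not part of the patch; the class names Leaf and Wrapper are illustrative only, and the assertions mirror the updated test and the new class-level LightningModule._jit_is_scripting flag, assuming a pytorch_lightning build that includes this change.

# Hypothetical sketch (not from the PR): scripting a LightningModule whose
# children are wrapped in a torch.nn.Sequential, the case this PR fixes.
import torch
from pytorch_lightning import LightningModule


class Leaf(LightningModule):  # illustrative name
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.layer(x)


class Wrapper(LightningModule):  # illustrative name
    def __init__(self):
        super().__init__()
        # A container of LightningModules: the old per-instance `_running_torchscript`
        # setter only recursed into direct LightningModule children, so modules nested
        # inside the Sequential kept the flag unset and their `trainer` property raised
        # a RuntimeError during scripting.
        self.model = torch.nn.Sequential(Leaf(), Leaf())

    def forward(self, x):
        return self.model(x)


if __name__ == "__main__":
    module = Wrapper()
    scripted = module.to_torchscript(method="script")
    assert isinstance(scripted, torch.jit.RecursiveScriptModule)
    # The new class attribute is only True inside the _jit_is_scripting() context manager,
    # so it is False again once export has finished.
    assert not LightningModule._jit_is_scripting

Because the flag now lives on the class rather than on each instance, every nested LightningModule observes it during scripting, regardless of how deeply it is buried inside container modules.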