From 8a7f504b6f84e64b02b3dd03133c1859fc4d56e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 2 Jul 2021 14:03:12 +0200
Subject: [PATCH] Detach hiddens and add test (#8249)

---
 pytorch_lightning/loops/batch/training_batch_loop.py | 4 +++-
 tests/trainer/logging_/test_train_loop_logging.py    | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index bd73dcfc12..64df877ce6 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -29,6 +29,7 @@ from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
@@ -345,7 +346,8 @@ class TrainingBatchLoop(Loop):
         if isinstance(training_step_output, dict):
             loss = training_step_output.pop("loss", None)
             hiddens = training_step_output.pop("hiddens", None)
-
+            # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
+            hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
             results.extra = training_step_output

         # handle scalar return
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index 9ea4eda5bc..b26e3fc83d 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -261,7 +261,9 @@ def test_tbptt_log(tmpdir):

         def training_step(self, batch, batch_idx, hiddens):
             assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            self.test_hidden = torch.rand(1)
+            if hiddens is not None:
+                assert hiddens.grad_fn is None
+            self.test_hidden = torch.tensor(2., requires_grad=True).pow(2)

             x_tensor, y_list = batch
             assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
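
Context for reviewers: below is a minimal, self-contained sketch (plain PyTorch, outside Lightning; the `rnn`/`head` modules and the tensor shapes are illustrative, not taken from this PR) of the failure mode the patch guards against. In truncated BPTT, each split calls `backward()` on its own loss, which frees that split's graph. If the hidden state carried into the next split is not detached, the next `backward()` retraverses the freed graph and raises `RuntimeError: Trying to backward through the graph a second time`.

```py
import torch

rnn = torch.nn.RNN(input_size=4, hidden_size=8, batch_first=True)
head = torch.nn.Linear(8, 1)
hiddens = None

# 3 tbptt splits, each of shape (batch=2, seq=5, features=4)
for split in torch.randn(3, 2, 5, 4):
    out, hiddens = rnn(split, hiddens)
    loss = head(out).mean()
    loss.backward()  # frees the graph built during this split
    # Without this detach, the next iteration's backward() would
    # walk back into the freed graph and raise the RuntimeError:
    hiddens = hiddens.detach()
```

This is the same idea `apply_to_collection(hiddens, Tensor, lambda t: t.detach())` implements in the loop, generalized so that `hiddens` may be a nested collection of tensors rather than a single one. The new test assertion `hiddens.grad_fn is None` checks exactly that the state handed back to `training_step` is detached.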