Detach hiddens and add test (#8249)

This commit is contained in:
Carlos Mocholí 2021-07-02 14:03:12 +02:00 committed by GitHub
parent 07b1ce227c
commit 8a7f504b6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 2 deletions

View File

@@ -29,6 +29,7 @@ from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
@@ -345,7 +346,8 @@ class TrainingBatchLoop(Loop):
if isinstance(training_step_output, dict):
loss = training_step_output.pop("loss", None)
hiddens = training_step_output.pop("hiddens", None)
# detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
results.extra = training_step_output
# handle scalar return

View File

@@ -261,7 +261,9 @@ def test_tbptt_log(tmpdir):
def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
self.test_hidden = torch.rand(1)
if hiddens is not None:
assert hiddens.grad_fn is None
self.test_hidden = torch.tensor(2., requires_grad=True).pow(2)
x_tensor, y_list = batch
assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"