Detach hiddens and add test (#8249)

parent 07b1ce227c
commit 8a7f504b6f
@@ -29,6 +29,7 @@ from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
@@ -345,7 +346,8 @@ class TrainingBatchLoop(Loop):
         if isinstance(training_step_output, dict):
             loss = training_step_output.pop("loss", None)
             hiddens = training_step_output.pop("hiddens", None)
-
+            # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
+            hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
             results.extra = training_step_output

         # handle scalar return
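For context, a minimal sketch of what the added detach line does; it is not part of the commit, and the tensor names are illustrative. apply_to_collection walks any nested structure and applies the given function to every Tensor it finds:

import torch
from torch import Tensor
from pytorch_lightning.utilities.apply_func import apply_to_collection

# hidden states as a recurrent training_step might return them:
# both tensors are still attached to the autograd graph (grad_fn is set)
h = torch.randn(1, 8, requires_grad=True).tanh()
c = torch.randn(1, 8, requires_grad=True).tanh()
hiddens = {"lstm": (h, c)}

detached = apply_to_collection(hiddens, Tensor, lambda t: t.detach())

# the nesting is preserved, but every tensor is cut from the graph
assert detached["lstm"][0].grad_fn is None
assert detached["lstm"][1].grad_fn is None
# detach() shares storage with the original, so nothing is copied
assert detached["lstm"][0].data_ptr() == h.data_ptr()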
@@ -261,7 +261,9 @@ def test_tbptt_log(tmpdir):
         def training_step(self, batch, batch_idx, hiddens):
             assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            self.test_hidden = torch.rand(1)
+            if hiddens is not None:
+                assert hiddens.grad_fn is None
+            self.test_hidden = torch.tensor(2., requires_grad=True).pow(2)

             x_tensor, y_list = batch
             assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"

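The failure mode this test guards against can be reproduced in plain PyTorch; the following hypothetical sketch is not from the test file. Carrying a non-detached hidden state into the next truncated-bptt split makes the second backward() reach back into the first split's already-freed graph:

import torch

rnn = torch.nn.RNN(input_size=4, hidden_size=4, batch_first=True)
opt = torch.optim.SGD(rnn.parameters(), lr=0.1)

hidden = None
for split in torch.randn(3, 1, 5, 4):  # three tbptt splits of one sequence
    out, hidden = rnn(split, hidden)
    loss = out.sum()
    opt.zero_grad()
    loss.backward()  # without the detach below, the second split raises
                     # `RuntimeError: Trying to backward through the graph a second time`
    opt.step()
    hidden = hidden.detach()  # roughly what the loop now does on the user's behalf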