diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 0f9fc1fd42..a11eaceee1 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -1475,6 +1475,11 @@ with the hidden # hiddens are the hiddens from the previous truncated backprop step out, hiddens = self.lstm(data, hiddens) + # remember to detach() hiddens. + # If you don't, you will get a RuntimeError: Trying to backward through + # the graph a second time... + # Using hiddens.detach() allows each split to be disconnected. + return { "loss": ..., "hiddens": hiddens # remember to detach() this @@ -1702,4 +1707,3 @@ The metrics sent to the progress bar. progress_bar_metrics = trainer.progress_bar_metrics assert progress_bar_metrics['a_val'] == 2 -