From 9669c80f293b0ef049d25a6bf8899533a42acaff Mon Sep 17 00:00:00 2001 From: chaton Date: Wed, 16 Dec 2020 22:07:35 +0100 Subject: [PATCH] [bugfix] remove nan loss in manual optimization (#5121) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * remove nan loss whe missing * Update pytorch_lightning/core/lightning.py Co-authored-by: Carlos MocholĂ­ * Apply suggestions from code review Co-authored-by: Carlos MocholĂ­ Co-authored-by: Rohit Gupta --- pytorch_lightning/core/lightning.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index cbdd86e24d..e8c19ec269 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1392,12 +1392,15 @@ class LightningModule( """ # call .item() only once but store elements without graphs running_train_loss = self.trainer.train_loop.running_loss.mean() - avg_training_loss = ( - running_train_loss.cpu().item() - if running_train_loss is not None - else float("NaN") - ) - tqdm_dict = {"loss": "{:.3g}".format(avg_training_loss)} + avg_training_loss = None + if running_train_loss is not None: + avg_training_loss = running_train_loss.cpu().item() + elif self.trainer.train_loop.automatic_optimization: + avg_training_loss = float('NaN') + + tqdm_dict = {} + if avg_training_loss is not None: + tqdm_dict["loss"] = f"{avg_training_loss:.3g}" if self.trainer.truncated_bptt_steps is not None: tqdm_dict["split_idx"] = self.trainer.split_idx