[bugfix] remove nan loss in manual optimization (#5121)

* remove nan loss whe missing

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
chaton 2020-12-16 22:07:35 +01:00 committed by Jirka Borovec
parent 13bbf4b3f2
commit 9669c80f29
1 changed files with 9 additions and 6 deletions

View File

@ -1392,12 +1392,15 @@ class LightningModule(
"""
# call .item() only once but store elements without graphs
running_train_loss = self.trainer.train_loop.running_loss.mean()
avg_training_loss = (
running_train_loss.cpu().item()
if running_train_loss is not None
else float("NaN")
)
tqdm_dict = {"loss": "{:.3g}".format(avg_training_loss)}
avg_training_loss = None
if running_train_loss is not None:
avg_training_loss = running_train_loss.cpu().item()
elif self.trainer.train_loop.automatic_optimization:
avg_training_loss = float('NaN')
tqdm_dict = {}
if avg_training_loss is not None:
tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
if self.trainer.truncated_bptt_steps is not None:
tqdm_dict["split_idx"] = self.trainer.split_idx