[bugfix] remove nan loss in manual optimization (#5121)
* remove nan loss whe missing * Update pytorch_lightning/core/lightning.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Apply suggestions from code review Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
13bbf4b3f2
commit
9669c80f29
|
@ -1392,12 +1392,15 @@ class LightningModule(
|
|||
"""
|
||||
# call .item() only once but store elements without graphs
|
||||
running_train_loss = self.trainer.train_loop.running_loss.mean()
|
||||
avg_training_loss = (
|
||||
running_train_loss.cpu().item()
|
||||
if running_train_loss is not None
|
||||
else float("NaN")
|
||||
)
|
||||
tqdm_dict = {"loss": "{:.3g}".format(avg_training_loss)}
|
||||
avg_training_loss = None
|
||||
if running_train_loss is not None:
|
||||
avg_training_loss = running_train_loss.cpu().item()
|
||||
elif self.trainer.train_loop.automatic_optimization:
|
||||
avg_training_loss = float('NaN')
|
||||
|
||||
tqdm_dict = {}
|
||||
if avg_training_loss is not None:
|
||||
tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
|
||||
|
||||
if self.trainer.truncated_bptt_steps is not None:
|
||||
tqdm_dict["split_idx"] = self.trainer.split_idx
|
||||
|
|
Loading…
Reference in New Issue