diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 5edc5ad0c3..16f2288229 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -281,11 +281,11 @@ class Trainer(TrainerIO): self.__run_validation() # when batch should be saved - if (batch_nb + 1) % self.log_save_interval == 0: + if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch: self.experiment.save() # when metrics should be logged - if batch_nb % self.add_log_row_interval == 0: + if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch: # count items in memory # nb_params, nb_tensors = count_mem_items()