make progress bar match internal epoch counter (#3061)
* fix 3018, 3032 * changed progress bar for 3032
This commit is contained in:
parent
7933a1121c
commit
10150fccb0
|
@ -113,7 +113,7 @@ class ProgressBarBase(Callback):
|
|||
"""
|
||||
total_val_batches = 0
|
||||
if not self.trainer.disable_validation:
|
||||
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
|
||||
is_val_epoch = (self.trainer.current_epoch) % self.trainer.check_val_every_n_epoch == 0
|
||||
total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
|
||||
return total_val_batches
|
||||
|
||||
|
@ -330,7 +330,7 @@ class ProgressBar(ProgressBarBase):
|
|||
total_batches = total_train_batches + total_val_batches
|
||||
if not self.main_progress_bar.disable:
|
||||
self.main_progress_bar.reset(convert_inf(total_batches))
|
||||
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')
|
||||
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
||||
super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
|
||||
|
|
Loading…
Reference in New Issue