Fix on_train_batch_end signature and call in ProgressBarBase example (#8836)

This commit is contained in:
Stefan Wijnja 2021-08-12 13:24:12 +01:00 committed by GitHub
parent 24f0124ddd
commit c77cd518b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -71,14 +71,14 @@ class ProgressBarBase(Callback):
class LitProgressBar(ProgressBarBase):
def __init__(self):
super().__init__() # don't forget this :)
super().__init__() # important :-)
self.enable = True
def disable(self):
self.enable = False
def on_train_batch_end(self, trainer, pl_module, outputs):
super().on_train_batch_end(trainer, pl_module, outputs) # don't forget this :)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) # important :-)
percent = (self.train_batch_idx / self.total_train_batches) * 100
sys.stdout.flush()
sys.stdout.write(f'{percent:.01f} percent complete \r')