Fix on_train_batch_end signature and call in ProgressBarBase example (#8836)
This commit is contained in:
parent
24f0124ddd
commit
c77cd518b5
|
@ -71,14 +71,14 @@ class ProgressBarBase(Callback):
|
|||
class LitProgressBar(ProgressBarBase):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__() # don't forget this :)
|
||||
super().__init__() # important :-)
|
||||
self.enable = True
|
||||
|
||||
def disable(self):
|
||||
self.enable = False
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs):
|
||||
super().on_train_batch_end(trainer, pl_module, outputs) # don't forget this :)
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) # important :-)
|
||||
percent = (self.train_batch_idx / self.total_train_batches) * 100
|
||||
sys.stdout.flush()
|
||||
sys.stdout.write(f'{percent:.01f} percent complete \r')
|
||||
|
|
Loading…
Reference in New Issue