parent f07176da9b
commit 570b2c7aeb
@@ -140,7 +140,7 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
         model = MODEL()
         trainer = Trainer(
             max_epochs=num_epochs,
-            show_progress_bar=False,
+            progress_bar_refresh_rate=0,
             weights_summary=None,
             gpus=1,
             early_stop_callback=False,
@@ -141,7 +141,7 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
         model = MODEL()
         trainer = Trainer(
             max_epochs=num_epochs,
-            show_progress_bar=False,
+            progress_bar_refresh_rate=0,
             weights_summary=None,
             gpus=1,
             early_stop_callback=False,
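The two benchmark hunks above make the same substitution: the removed show_progress_bar=False flag becomes progress_bar_refresh_rate=0. Below is a minimal sketch of the resulting Trainer call, assuming a PyTorch Lightning release from the era of this commit; the gpus=1 argument from the hunk is dropped here so the snippet also runs without a GPU.

from pytorch_lightning import Trainer

# Same arguments as the benchmark hunk, minus gpus=1.
trainer = Trainer(
    max_epochs=10,
    progress_bar_refresh_rate=0,   # replaces the removed show_progress_bar=False
    weights_summary=None,
    early_stop_callback=False,
)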
@@ -114,7 +114,7 @@ class TrainerLRFinderMixin(ABC):
         lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
         # Use special lr logger callback
-        self.callbacks = [_LRCallback(num_training, show_progress_bar=True)]
+        self.callbacks = [_LRCallback(num_training, progress_bar_refresh_rate=1)]
 
         # No logging
         self.logger = None
@@ -310,19 +310,19 @@ class _LRCallback(Callback):
     """ Special callback used by the learning rate finder. This callbacks log
     the learning rate before each batch and log the corresponding loss after
     each batch. """
-    def __init__(self, num_training: int, show_progress_bar: bool = False, beta: float = 0.98):
+    def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98):
         self.num_training = num_training
         self.beta = beta
         self.losses = []
         self.lrs = []
         self.avg_loss = 0.0
         self.best_loss = 0.0
-        self.show_progress_bar = show_progress_bar
+        self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.progress_bar = None
 
     def on_batch_start(self, trainer, pl_module):
         """ Called before each training batch, logs the lr that will be used """
-        if self.show_progress_bar and self.progress_bar is None:
+        if self.progress_bar_refresh_rate and self.progress_bar is None:
             self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)
 
         self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
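The _LRCallback hunk keeps the old default behaviour: a refresh rate of 0 is falsy in the same way False was, so no tqdm bar is created unless the caller asks for one. The snippet below is a simplified, self-contained stand-in for that pattern, not the callback from the diff; the ProgressLogger class and the loop driving it are illustrative only.

from tqdm import tqdm

class ProgressLogger:
    def __init__(self, num_training: int, progress_bar_refresh_rate: int = 0):
        self.num_training = num_training
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.progress_bar = None

    def on_batch_start(self):
        # A refresh rate of 0 is falsy, so the bar is never built and nothing is drawn.
        if self.progress_bar_refresh_rate and self.progress_bar is None:
            self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)

    def on_batch_end(self):
        if self.progress_bar is not None:
            self.progress_bar.update()

logger = ProgressLogger(num_training=100, progress_bar_refresh_rate=1)
for _ in range(100):
    logger.on_batch_start()
    logger.on_batch_end()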
@@ -54,7 +54,7 @@ def test_multi_cpu_model_ddp(tmpdir):
     model, hparams = tutils.get_default_model()
     trainer_options = dict(
         default_root_dir=tmpdir,
-        show_progress_bar=False,
+        progress_bar_refresh_rate=0,
         max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.2,
@@ -539,7 +539,7 @@ def test_disabled_validation():
     model = CurrentModel(hparams)
 
     trainer_options = dict(
-        show_progress_bar=False,
+        progress_bar_refresh_rate=0,
         max_epochs=2,
         train_percent_check=0.4,
         val_percent_check=0.0,
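The test hunks follow the same recipe: wherever a trainer_options dict carried show_progress_bar=False, it now carries progress_bar_refresh_rate=0. A hedged sketch of that pattern is below, assuming the Trainer of this era still accepts train_percent_check and val_percent_check; the root dir string is a placeholder for pytest's tmpdir fixture, not part of the diff.

from pytorch_lightning import Trainer

trainer_options = dict(
    default_root_dir='/tmp/test_run',   # placeholder for the tmpdir fixture
    progress_bar_refresh_rate=0,
    max_epochs=1,
    train_percent_check=0.4,
    val_percent_check=0.2,
)
trainer = Trainer(**trainer_options)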