fix depreated call (#1596)

* fix parity

* update deprecated call
This commit is contained in:
Jirka Borovec 2020-04-24 20:45:43 +02:00 committed by GitHub
parent f07176da9b
commit 570b2c7aeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 8 additions and 8 deletions

View File

@ -140,7 +140,7 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
model = MODEL()
trainer = Trainer(
max_epochs=num_epochs,
show_progress_bar=False,
progress_bar_refresh_rate=0,
weights_summary=None,
gpus=1,
early_stop_callback=False,

View File

@ -141,7 +141,7 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
model = MODEL()
trainer = Trainer(
max_epochs=num_epochs,
show_progress_bar=False,
progress_bar_refresh_rate=0,
weights_summary=None,
gpus=1,
early_stop_callback=False,

View File

@ -114,7 +114,7 @@ class TrainerLRFinderMixin(ABC):
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
# Use special lr logger callback
self.callbacks = [_LRCallback(num_training, show_progress_bar=True)]
self.callbacks = [_LRCallback(num_training, progress_bar_refresh_rate=1)]
# No logging
self.logger = None
@ -310,19 +310,19 @@ class _LRCallback(Callback):
""" Special callback used by the learning rate finder. This callbacks log
the learning rate before each batch and log the corresponding loss after
each batch. """
def __init__(self, num_training: int, show_progress_bar: bool = False, beta: float = 0.98):
def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98):
self.num_training = num_training
self.beta = beta
self.losses = []
self.lrs = []
self.avg_loss = 0.0
self.best_loss = 0.0
self.show_progress_bar = show_progress_bar
self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar = None
def on_batch_start(self, trainer, pl_module):
""" Called before each training batch, logs the lr that will be used """
if self.show_progress_bar and self.progress_bar is None:
if self.progress_bar_refresh_rate and self.progress_bar is None:
self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)
self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

View File

@ -54,7 +54,7 @@ def test_multi_cpu_model_ddp(tmpdir):
model, hparams = tutils.get_default_model()
trainer_options = dict(
default_root_dir=tmpdir,
show_progress_bar=False,
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,

View File

@ -539,7 +539,7 @@ def test_disabled_validation():
model = CurrentModel(hparams)
trainer_options = dict(
show_progress_bar=False,
progress_bar_refresh_rate=0,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.0,