Remove deprecated `progress_bar_refresh_rate` from Trainer constructor (#12514)

* Remove progress_bar_refresh_rate from Trainer constructor
* changelog
This commit is contained in:
Danielle Pintz 2022-03-30 11:47:55 -04:00 committed by GitHub
parent 32e68c5e79
commit 1acb0dcbf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 22 additions and 104 deletions

View File

@ -57,7 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
-
- Removed the deprecated `progress_bar_refresh_rate` argument from the `Trainer` constructor ([#12514](https://github.com/PyTorchLightning/pytorch-lightning/pull/12514))
-

View File

@ -1306,37 +1306,6 @@ See the :doc:`profiler documentation <../advanced/profiler>`. for more details.
# advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
trainer = Trainer(profiler="advanced")
progress_bar_refresh_rate
^^^^^^^^^^^^^^^^^^^^^^^^^
.. warning:: ``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
pass ``enable_progress_bar = False`` to the Trainer.
.. raw:: html
<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/progress_bar%E2%80%A8_refresh_rate.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/progress_bar_refresh_rate.mp4"></video>
|
How often to refresh progress bar (in steps).
.. testcode::
# default used by the Trainer
trainer = Trainer(progress_bar_refresh_rate=1)
# disable progress bar
trainer = Trainer(progress_bar_refresh_rate=0)
Note:
- In Google Colab notebooks, faster refresh rates (lower number) is known to crash them because of their screen refresh rates.
Lightning will set it to 20 in these environments if the user does not provide a value.
- This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.
enable_progress_bar
^^^^^^^^^^^^^^^^^^^

View File

@ -90,10 +90,7 @@ class TQDMProgressBar(ProgressBarBase):
Args:
refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
Set it to ``0`` to disable the display. By default, the :class:`~pytorch_lightning.trainer.trainer.Trainer`
uses this implementation of the progress bar and sets the refresh rate to the value provided to the
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
Set it to ``0`` to disable the display.
process_position: Set this to a value greater than ``0`` to offset the progress bars by this many lines.
This is useful when you have progress bars defined elsewhere and want to show all of them
together. This corresponds to

View File

@ -41,7 +41,6 @@ class CallbackConnector:
checkpoint_callback: Optional[bool],
enable_checkpointing: bool,
enable_progress_bar: bool,
progress_bar_refresh_rate: Optional[int],
process_position: int,
default_root_dir: Optional[str],
weights_save_path: Optional[str],
@ -92,15 +91,7 @@ class CallbackConnector:
" `process_position` directly to the Trainer's `callbacks` argument instead."
)
if progress_bar_refresh_rate is not None:
rank_zero_deprecation(
f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
" will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with"
" `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress"
" bar pass `enable_progress_bar = False` to the Trainer."
)
self._configure_progress_bar(progress_bar_refresh_rate, process_position, enable_progress_bar)
self._configure_progress_bar(process_position, enable_progress_bar)
# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary, weights_summary)
@ -220,9 +211,7 @@ class CallbackConnector:
if not existing_swa:
self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
def _configure_progress_bar(
self, refresh_rate: Optional[int] = None, process_position: int = 0, enable_progress_bar: bool = True
) -> None:
def _configure_progress_bar(self, process_position: int = 0, enable_progress_bar: bool = True) -> None:
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
if len(progress_bars) > 1:
raise MisconfigurationException(
@ -243,14 +232,9 @@ class CallbackConnector:
f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list."
)
# Return early if the user intends to disable the progress bar callback
if refresh_rate == 0 or not enable_progress_bar:
return
if refresh_rate is None:
refresh_rate = 1
progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
self.trainer.callbacks.append(progress_bar_callback)
if enable_progress_bar:
progress_bar_callback = TQDMProgressBar(process_position=process_position)
self.trainer.callbacks.append(progress_bar_callback)
def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
if max_time is None:

View File

@ -146,7 +146,6 @@ class Trainer(
tpu_cores: Optional[Union[List[int], str, int]] = None,
ipus: Optional[int] = None,
log_gpu_memory: Optional[str] = None, # TODO: Remove in 1.7
progress_bar_refresh_rate: Optional[int] = None, # TODO: remove in v1.7
enable_progress_bar: bool = True,
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
@ -331,16 +330,6 @@ class Trainer(
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``process_position``
directly to the Trainer's ``callbacks`` argument instead.
progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).
.. deprecated:: v1.5
``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
pass ``enable_progress_bar = False`` to the Trainer.
enable_progress_bar: Whether to enable to progress bar by default.
Default: ``False``.
@ -546,7 +535,6 @@ class Trainer(
checkpoint_callback,
enable_checkpointing,
enable_progress_bar,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path,

View File

@ -71,33 +71,25 @@ class MockTqdm(Tqdm):
@pytest.mark.parametrize(
"kwargs",
"pbar",
[
# won't print but is still set
{"callbacks": TQDMProgressBar(refresh_rate=0)},
{"callbacks": TQDMProgressBar()},
{"progress_bar_refresh_rate": 1},
TQDMProgressBar(refresh_rate=0),
TQDMProgressBar(),
],
)
def test_tqdm_progress_bar_on(tmpdir, kwargs):
def test_tqdm_progress_bar_on(tmpdir, pbar):
"""Test different ways the progress bar can be turned on."""
if "progress_bar_refresh_rate" in kwargs:
with pytest.deprecated_call(match=r"progress_bar_refresh_rate=.*` is deprecated"):
trainer = Trainer(default_root_dir=tmpdir, **kwargs)
else:
trainer = Trainer(default_root_dir=tmpdir, **kwargs)
trainer = Trainer(default_root_dir=tmpdir, callbacks=pbar)
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
assert len(progress_bars) == 1
assert progress_bars[0] is trainer.progress_bar_callback
@pytest.mark.parametrize("kwargs", [{"enable_progress_bar": False}, {"progress_bar_refresh_rate": 0}])
def test_tqdm_progress_bar_off(tmpdir, kwargs):
"""Test different ways the progress bar can be turned off."""
if "progress_bar_refresh_rate" in kwargs:
pytest.deprecated_call(match=r"progress_bar_refresh_rate=.*` is deprecated").__enter__()
trainer = Trainer(default_root_dir=tmpdir, **kwargs)
def test_tqdm_progress_bar_off(tmpdir):
"""Test turning the progress bar off."""
trainer = Trainer(default_root_dir=tmpdir, enable_progress_bar=False)
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
assert not len(progress_bars)
@ -263,15 +255,13 @@ def test_tqdm_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
self.test_batches_seen += 1
pbar = CurrentProgressBar(refresh_rate=refresh_rate)
with pytest.deprecated_call(match=r"progress_bar_refresh_rate=101\)` is deprecated"):
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[pbar],
progress_bar_refresh_rate=101, # should not matter if custom callback provided
limit_train_batches=1.0,
num_sanity_val_steps=2,
max_epochs=3,
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[pbar],
limit_train_batches=1.0,
num_sanity_val_steps=2,
max_epochs=3,
)
assert trainer.progress_bar_callback.refresh_rate == refresh_rate
trainer.fit(model)
@ -350,10 +340,6 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar(refresh_rate=19))
assert trainer.progress_bar_callback.refresh_rate == 19
with pytest.deprecated_call(match=r"progress_bar_refresh_rate=19\)` is deprecated"):
trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19)
assert trainer.progress_bar_callback.refresh_rate == 19
@pytest.mark.parametrize(
"train_batches,val_batches,refresh_rate,train_updates,val_updates",

View File

@ -112,7 +112,6 @@ def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=None,
fast_dev_run=True,
)
test_model = TestModel()
@ -257,11 +256,6 @@ def test_v1_7_0_deprecate_add_get_queue(tmpdir):
trainer.fit(model)
def test_v1_7_0_progress_bar_refresh_rate_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(progress_bar_refresh_rate=1\)` is deprecated in v1.5"):
_ = Trainer(progress_bar_refresh_rate=1)
def test_v1_7_0_lightning_logger_base_close(tmpdir):
logger = CustomLogger()
with pytest.deprecated_call(