diff --git a/CHANGELOG.md b/CHANGELOG.md index b56321765b..f029aeadb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185)) +- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516)) + + ### Deprecated - `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index ca6c574c0c..c2c972b730 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -1247,8 +1247,6 @@ progress_bar_refresh_rate | How often to refresh progress bar (in steps). -In notebooks, faster refresh rates (lower number) is known to crash them -because of their screen refresh rates, so raise it to 50 or more. .. testcode:: @@ -1259,7 +1257,9 @@ because of their screen refresh rates, so raise it to 50 or more. trainer = Trainer(progress_bar_refresh_rate=0) Note: - This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + - In Google Colab notebooks, faster refresh rates (lower number) is known to crash them because of their screen refresh rates. + Lightning will set it to 20 in these environments if the user does not provide a value. + - This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. reload_dataloaders_every_epoch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 72a0641a08..61e59d2fb9 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -75,14 +75,11 @@ class CallbackConnector: if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True: self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min')) - def configure_progress_bar(self, refresh_rate=1, process_position=0): - # smaller refresh rate on colab causes crashes, warn user about this - if os.getenv('COLAB_GPU') and refresh_rate < 20: - rank_zero_warn( - "You have set progress_bar_refresh_rate < 20 on Google Colab. This" - " may crash. Consider using progress_bar_refresh_rate >= 20 in Trainer.", - UserWarning - ) + def configure_progress_bar(self, refresh_rate=None, process_position=0): + if os.getenv('COLAB_GPU') and refresh_rate is None: + # smaller refresh rate on colab causes crashes, choose a higher value + refresh_rate = 20 + refresh_rate = 1 if refresh_rate is None else refresh_rate progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] if len(progress_bars) > 1: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 59d745a2be..0a2362a438 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -95,7 +95,7 @@ class Trainer( auto_select_gpus: bool = False, tpu_cores: Optional[Union[List[int], str, int]] = None, log_gpu_memory: Optional[str] = None, - progress_bar_refresh_rate: int = 1, + progress_bar_refresh_rate: Optional[int] = None, overfit_batches: Union[int, float] = 0.0, track_grad_norm: Union[int, float, str] = -1, check_val_every_n_epoch: int = 1, @@ -219,7 +219,8 @@ class Trainer( process_position: orders the progress bar when running multiple models on same machine. progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. - Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`. + Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means + a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.). profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool value is deprecated in v1.1 and will be removed in v1.3. diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 8840dae54a..0320a3dbd0 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -24,6 +24,7 @@ from tests.base import BoringModel, EvalModelTemplate @pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], None), ([], 1), ([], 2), ([ProgressBar(refresh_rate=1)], 0), @@ -245,14 +246,25 @@ def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected): assert trainer.progress_bar_callback.val_progress_bar_total == expected -@mock.patch.dict(os.environ, {'COLAB_GPU': '1'}) -def test_progress_bar_warning_on_colab(tmpdir): - with pytest.warns(UserWarning, match='on Google Colab. This may crash.'): - trainer = Trainer( - default_root_dir=tmpdir, - progress_bar_refresh_rate=19, - ) +def test_progress_bar_default_value(tmpdir): + """ Test that a value of None defaults to refresh rate 1. """ + trainer = Trainer(default_root_dir=tmpdir) + assert trainer.progress_bar_callback.refresh_rate == 1 + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None) + assert trainer.progress_bar_callback.refresh_rate == 1 + + +@mock.patch.dict(os.environ, {'COLAB_GPU': '1'}) +def test_progress_bar_value_on_colab(tmpdir): + """ Test that Trainer will override the default in Google COLAB. """ + trainer = Trainer(default_root_dir=tmpdir) + assert trainer.progress_bar_callback.refresh_rate == 20 + + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None) + assert trainer.progress_bar_callback.refresh_rate == 20 + + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19) assert trainer.progress_bar_callback.refresh_rate == 19