diff --git a/CHANGELOG.md b/CHANGELOG.md index 08e2e93b93..f8ed74235c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344)) +- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656)) + + ### Changed @@ -48,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237)) +- Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656)) + + ### Removed diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index fc684d143e..fe14bb5751 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -22,12 +22,12 @@ PyTorch Lightning supports profiling standard actions in the training loop out o Enable simple profiling ----------------------- -If you only wish to profile the standard actions, you can set `profiler=True` when constructing -your `Trainer` object. +If you only wish to profile the standard actions, you can set `profiler="simple"` +when constructing your `Trainer` object. .. code-block:: python - trainer = Trainer(..., profiler=True) + trainer = Trainer(..., profiler="simple") The profiler's results will be printed at the completion of a training `fit()`. @@ -59,6 +59,10 @@ This option uses Python's cProfiler_ to provide a report of time spent on *each* .. code-block:: python + trainer = Trainer(..., profiler="advanced") + + or + profiler = AdvancedProfiler() trainer = Trainer(..., profiler=profiler) diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 6109c21b82..a4bf2969f4 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -1199,14 +1199,11 @@ See the :ref:`profiler documentation `. for more details. # default used by the Trainer trainer = Trainer(profiler=None) - # to profile standard training events - trainer = Trainer(profiler=True) + # to profile standard training events, equivalent to `profiler=SimpleProfiler()` + trainer = Trainer(profiler="simple") - # equivalent to profiler=True - trainer = Trainer(profiler=SimpleProfiler()) - - # advanced profiler for function-level stats - trainer = Trainer(profiler=AdvancedProfiler()) + # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()` + trainer = Trainer(profiler="advanced") progress_bar_refresh_rate ^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 17aed23ab5..0f6686f1f8 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -11,7 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License -from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler + +from typing import Union + +from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler, AdvancedProfiler +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException class ProfilerConnector: @@ -19,8 +24,27 @@ class ProfilerConnector: def __init__(self, trainer): self.trainer = trainer - def on_trainer_init(self, profiler): - # configure profiler - if profiler is True: - profiler = SimpleProfiler() + def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]): + + if profiler and not isinstance(profiler, (bool, str, BaseProfiler)): + # TODO: Update exception on removal of bool + raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` " + "are valid values for `Trainer`'s `profiler` parameter. " + f"Received {profiler} which is of type {type(profiler)}.") + + if isinstance(profiler, bool): + rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated" + " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", + DeprecationWarning) + if profiler: + profiler = SimpleProfiler() + elif isinstance(profiler, str): + profiler = profiler.lower() + if profiler == "simple": + profiler = SimpleProfiler() + elif profiler == "advanced": + profiler = AdvancedProfiler() + else: + raise ValueError("When passing string value for the `profiler` parameter of" + " `Trainer`, it can only be 'simple' or 'advanced'") self.trainer.profiler = profiler or PassThroughProfiler() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 44250ae905..337eb4c4ed 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -120,7 +120,7 @@ class Trainer( num_sanity_val_steps: int = 2, truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[str] = None, - profiler: Optional[Union[BaseProfiler, bool]] = None, + profiler: Optional[Union[BaseProfiler, bool, str]] = None, benchmark: bool = False, deterministic: bool = False, reload_dataloaders_every_epoch: bool = False, @@ -212,7 +212,8 @@ class Trainer( progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`. - profiler: To profile individual steps during training and assist in identifying bottlenecks. + profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool + value is deprecated in v1.1 and will be removed in v1.3. overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0 diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index f3cf2e5f1b..bbb89ad09a 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -174,8 +174,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: # if the only arg type is bool if len(arg_types) == 1: use_type = parsing.str_to_bool - # if only two args (str, bool) - elif len(arg_types) == 2 and set(arg_types) == {str, bool}: + elif str in arg_types: use_type = parsing.str_to_bool_or_str else: # filter out the bool as we need to use more general diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index c8a7b1d270..60f13383d3 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -1,13 +1,17 @@ """Test deprecated functionality which will be removed in vX.Y.Z""" +from argparse import ArgumentParser import pytest import sys +from unittest import mock import torch from tests.base import EvalModelTemplate from pytorch_lightning.metrics.functional.classification import auc +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -22,6 +26,37 @@ def test_tbd_remove_in_v1_2_0(): checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.') +# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py +@pytest.mark.parametrize(['profiler', 'expected'], [ + (True, SimpleProfiler), + (False, PassThroughProfiler), +]) +def test_trainer_profiler_remove_in_v1_3_0(profiler, expected): + with pytest.deprecated_call(match='will be removed in v1.3'): + trainer = Trainer(profiler=profiler) + assert isinstance(trainer.profiler, expected) + + +@pytest.mark.parametrize( + ['cli_args', 'expected_parsed_arg', 'expected_profiler'], + [ + ('--profiler', True, SimpleProfiler), + ('--profiler True', True, SimpleProfiler), + ('--profiler False', False, PassThroughProfiler), + ], +) +def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, expected_profiler): + cli_args = cli_args.split(' ') + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parent_parser=parser) + args = Trainer.parse_argparser(parser) + + assert getattr(args, "profiler") == expected_parsed_arg + trainer = Trainer.from_argparse_args(args) + assert isinstance(trainer.profiler, expected_profiler) + + def _soft_unimport_module(str_module): # once the module is imported e.g with parsing with pytest it lives in memory if str_module in sys.modules: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4841d1461f..35257e2870 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -32,6 +32,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -1408,3 +1409,32 @@ def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, l trainer.fit(model) expected_calls = [call(metrics=ANY, step=s) for s in range(log_interval - 1, max_steps, log_interval)] log_metrics_mock.assert_has_calls(expected_calls) + + +@pytest.mark.parametrize(['profiler', 'expected'], [ + (None, PassThroughProfiler), + (SimpleProfiler(), SimpleProfiler), + (AdvancedProfiler(), AdvancedProfiler), + ('simple', SimpleProfiler), + ('Simple', SimpleProfiler), + ('advanced', AdvancedProfiler), +]) +def test_trainer_profiler_correct_args(profiler, expected): + kwargs = {'profiler': profiler} if profiler is not None else {} + trainer = Trainer(**kwargs) + assert isinstance(trainer.profiler, expected) + + +def test_trainer_profiler_incorrect_str_arg(): + with pytest.raises(ValueError, match=r".*can only be 'simple' or 'advanced'"): + Trainer(profiler="unknown_profiler") + + +@pytest.mark.parametrize('profiler', ( + 42, [42], {"a": 42}, torch.tensor(42), Trainer(), +)) +def test_trainer_profiler_incorrect_arg_type(profiler): + with pytest.raises(MisconfigurationException, + match=r"Only None, bool, str and subclasses of `BaseProfiler` " + r"are valid values for `Trainer`'s `profiler` parameter. *"): + Trainer(profiler=profiler)