From 77fb425dd4588534504b99f0b562b3cde8b44dbd Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 10 Dec 2020 08:38:14 +0100 Subject: [PATCH] update usage of deprecated profiler (#5010) * drop deprecated profiler * lut Co-authored-by: Roger Shieh --- pl_examples/domain_templates/imagenet.py | 2 +- .../trainer/connectors/profiler_connector.py | 19 +++++++++++-------- tests/trainer/test_trainer.py | 4 ++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index f02b6dc095..b7116547d3 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -237,7 +237,7 @@ def run_cli(): help='seed for initializing training.') parser = ImageNetLightningModel.add_model_specific_args(parent_parser) parser.set_defaults( - profiler=True, + profiler="simple", deterministic=True, max_epochs=90, ) diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 0f6686f1f8..3ecc168231 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -18,6 +18,11 @@ from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, Simple from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +PROFILERS = { + "simple": SimpleProfiler, + "advanced": AdvancedProfiler, +} + class ProfilerConnector: @@ -28,9 +33,9 @@ class ProfilerConnector: if profiler and not isinstance(profiler, (bool, str, BaseProfiler)): # TODO: Update exception on removal of bool - raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` " - "are valid values for `Trainer`'s `profiler` parameter. " - f"Received {profiler} which is of type {type(profiler)}.") + raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`" + " are valid values for `Trainer`'s `profiler` parameter." + f" Received {profiler} which is of type {type(profiler)}.") if isinstance(profiler, bool): rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated" @@ -39,11 +44,9 @@ class ProfilerConnector: if profiler: profiler = SimpleProfiler() elif isinstance(profiler, str): - profiler = profiler.lower() - if profiler == "simple": - profiler = SimpleProfiler() - elif profiler == "advanced": - profiler = AdvancedProfiler() + if profiler.lower() in PROFILERS: + profiler_class = PROFILERS[profiler.lower()] + profiler = profiler_class() else: raise ValueError("When passing string value for the `profiler` parameter of" " `Trainer`, it can only be 'simple' or 'advanced'") diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a25067b136..9b29d6ec2b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1476,6 +1476,6 @@ def test_trainer_profiler_incorrect_str_arg(): )) def test_trainer_profiler_incorrect_arg_type(profiler): with pytest.raises(MisconfigurationException, - match=r"Only None, bool, str and subclasses of `BaseProfiler` " - r"are valid values for `Trainer`'s `profiler` parameter. *"): + match=r"Only None, bool, str and subclasses of `BaseProfiler`" + r" are valid values for `Trainer`'s `profiler` parameter. *"): Trainer(profiler=profiler)